From 672d6f8e26712fdcbec85d2e81666d68c4ed9476 Mon Sep 17 00:00:00 2001 From: SneedwareInc <254158255+SneedwareInc@users.noreply.github.com> Date: Tue, 27 Jan 2026 21:20:31 +0100 Subject: [PATCH 1/8] Update server Did not stress-test properly, but basic functionality is there. --- examples/server/server.cpp | 4790 +++++++++++++++++++++--------------- 1 file changed, 2766 insertions(+), 2024 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index fa792f9ec..81a8006f1 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1,2024 +1,2766 @@ -#pragma warning(disable : 4996) -#include "server-context.h" -#include "server-common.h" -#include "chat.h" - -#include "common.h" -#include "speculative.h" -#include "mtmd.h" -#include "sampling.h" -#include "llama.h" -#include "llama-vocab.h" -#include - - -// mime type for sending response -#define MIMETYPE_JSON "application/json; charset=utf-8" - - -#ifndef NDEBUG -// crash the server in debug mode, otherwise send an http 500 error -#define CPPHTTPLIB_NO_EXCEPTIONS 1 -#endif - -#include -#include "index.html.gz.hpp" -#include "index_llamacpp.html.gz.hpp" -#include "loading.html.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#ifdef SQLITE3_MODERN_CPP_SUPPORT -#include - -struct DatabaseHandle { - sqlite::database db; - - DatabaseHandle(const std::string& path) : db(path) { - db << "CREATE TABLE IF NOT EXISTS sessions (key TEXT PRIMARY KEY, data TEXT)"; - db << "CREATE TABLE IF NOT EXISTS templates (key TEXT PRIMARY KEY, data TEXT)"; - db << "CREATE TABLE IF NOT EXISTS names (key TEXT PRIMARY KEY, data TEXT)"; - } -}; -#endif - -using json = nlohmann::ordered_json; -namespace fs = std::filesystem; -constexpr int HTTP_POLLING_SECONDS = 1; - -bool server_verbose = false; -bool server_log_json = true; - - -enum server_state { - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded - SERVER_STATE_ERROR // An error occurred, load_model failed -}; - - -static inline std::string stop_type_to_str(stop_type type) { - switch (type) { - case STOP_TYPE_EOS: return "eos"; - case STOP_TYPE_WORD: return "word"; - case STOP_TYPE_LIMIT: return "limit"; - default: return "none"; - } -} - - -inline std::string get_model_name(std::string path) -{ - std::string filename = path.substr(path.find_last_of("/\\") + 1); - return filename; -}; - - -static json format_final_response_oaicompat(const json& request, json result, const std::string& completion_id, bool streaming = false) { - bool stopped_word = result.count("stopped_word") != 0; - bool stopped_eos = json_value(result, "stopped_eos", false); - int num_tokens_predicted = json_value(result, "tokens_predicted", 0); - int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); - std::string content = json_value(result, "content", std::string("")); - - std::string finish_reason = "length"; - if (stopped_word || stopped_eos) { - finish_reason = "stop"; - } - - json choices = - streaming ? json::array({ json{{"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}} }) - : json::array({ json{{"finish_reason", finish_reason}, - {"index", 0}, - {"message", json{{"content", content}, - {"role", "assistant"}}}} }); - - std::time_t t = std::time(0); - - json res = json{ - {"choices", choices}, - {"created", t}, - {"model", - json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", streaming ? "chat.completion.chunk" : "chat.completion"}, - {"usage", json { - {"completion_tokens", num_tokens_predicted}, - {"prompt_tokens", num_prompt_tokens}, - {"total_tokens", num_tokens_predicted + num_prompt_tokens} - }}, - {"id", completion_id} - }; - - if (server_verbose) { - res["__verbose"] = result; - } - - if (result.contains("completion_probabilities")) { - res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array()); - } - - return res; -} - -// return value is vector as there is one case where we might need to generate two responses -static std::vector format_partial_response_oaicompat(server_task_result task_result, const std::string& completion_id) { - json result = task_result.data; - std::cout << result.dump(4) << std::endl; - if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) { - return std::vector({ result }); - } - - bool first = json_value(result, "oaicompat_token_ctr", 0) == 0; - std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); - - bool stopped_word = json_value(result, "stopped_word", false); - bool stopped_eos = json_value(result, "stopped_eos", false); - bool stopped_limit = json_value(result, "stopped_limit", false); - std::string content = json_value(result, "content", std::string("")); - - std::string finish_reason; - if (stopped_word || stopped_eos) { - finish_reason = "stop"; - } - if (stopped_limit) { - finish_reason = "length"; - } - - std::time_t t = std::time(0); - - json choices; - - if (!finish_reason.empty()) { - choices = json::array({ json{{"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}} }); - } - else { - if (first) { - if (content.empty()) { - choices = json::array({ json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{{"role", "assistant"}}}} }); - } - else { - // We have to send this as two updates to conform to openai behavior - json initial_ret = json{ {"choices", json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"role", "assistant"} - }}}})}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"} }; - - json second_ret = json{ - {"choices", json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"content", content}}} - }})}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"} }; - - return std::vector({ initial_ret, second_ret }); - } - } - else { - // Some idiosyncrasy in task processing logic makes several trailing calls - // with empty content, we ignore these at the calee site. - if (content.empty()) { - return std::vector({ json::object() }); - } - - choices = json::array({ json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", - json{ - {"content", content}, - }}, - } }); - } - } - - json ret = json{ - {"choices", choices}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"} - }; - - if (task_result.timings.prompt_n != -1) { - ret.push_back({ "timings", task_result.timings.to_json() }); - } - - // - if (!finish_reason.empty()) { - int num_tokens_predicted = json_value(result, "tokens_predicted", 0); - int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); - ret.push_back({ "usage", json { - {"completion_tokens", num_tokens_predicted}, - {"prompt_tokens", num_prompt_tokens}, - {"total_tokens", num_tokens_predicted + num_prompt_tokens} - } }); - } - - return std::vector({ ret }); -} - - -static json format_embeddings_response_oaicompat(const json& request, const json& embeddings, bool use_base64 = false) { - json data = json::array(); - int32_t n_tokens = 0; - int i = 0; - for (const auto& elem : embeddings) { - json embedding_obj; - - if (use_base64) { - const auto& vec = json_value(elem, "embedding", json::array()).get>(); - const char* data_ptr = reinterpret_cast(vec.data()); - size_t data_size = vec.size() * sizeof(float); - embedding_obj = { - {"embedding", base64::encode(data_ptr, data_size)}, - {"index", i++}, - {"object", "embedding"}, - {"encoding_format", "base64"} - }; - } - else { - embedding_obj = { - {"embedding", json_value(elem, "embedding", json::array())}, - {"index", i++}, - {"object", "embedding"} - }; - } - data.push_back(embedding_obj); - n_tokens += json_value(elem, "tokens_evaluated", 0); - } - json res = json{ - {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", "list"}, - {"usage", json { - {"prompt_tokens", n_tokens}, - {"total_tokens", n_tokens} - }}, - {"data", data} - }; - - return res; -} - -static void log_server_request(const httplib::Request & req, const httplib::Response & res) { - // skip GH copilot requests when using default port - if (req.path == "/v1/health" || req.path == "/v1/completions") { - return; - } - - LOG_INFO("request", { - {"remote_addr", req.remote_addr}, - {"remote_port", req.remote_port}, - {"status", res.status}, - {"method", req.method}, - {"path", req.path}, - {"params", req.params}, - }); - - LOG_VERBOSE("request", { - {"request", req.body}, - {"response", res.body}, - }); -} - -// generator-like API for server responses, support pooling connection state and aggregating results -struct server_response_reader { - std::unordered_set id_tasks; - server_context& ctx_server; - size_t received_count = 0; - bool cancelled = false; - - server_response_reader(server_context& ctx_server) : ctx_server(ctx_server) {} - ~server_response_reader() { - stop(); - } - - void post_tasks(std::vector&& tasks) { - id_tasks = server_task::get_list_id(tasks); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(std::move(tasks)); - } - - bool has_next() { - return !cancelled && received_count < id_tasks.size(); - } - - // return nullptr if should_stop() is true before receiving a result - // note: if one error is received, it will stop further processing and return error result - server_task_result_ptr next(const std::function& should_stop) { - while (true) { - server_task_result_ptr result = ctx_server.queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); - if (result == nullptr) { - // timeout, check stop condition - if (should_stop()) { - SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n"); - return nullptr; - } - } - else { - if (result->is_error()) { - stop(); // cancel remaining tasks - SRV_DBG("%s", "received error result, stopping further processing\n"); - return result; - } - if (result->is_stop()) { - received_count++; - } - return result; - } - } - - // should not reach here - } - - struct batch_response { - bool is_terminated = false; // if true, indicates that processing was stopped before all results were received - std::vector results; - server_task_result_ptr error; // nullptr if no error - }; - - batch_response wait_for_all(const std::function& should_stop) { - batch_response batch_res; - batch_res.results.resize(id_tasks.size()); - while (has_next()) { - auto res = next(should_stop); - if (res == nullptr) { - batch_res.is_terminated = true; - return batch_res; - } - if (res->error) { - batch_res.error = std::move(res); - return batch_res; - } - const size_t idx = res->get_index(); - GGML_ASSERT(idx < batch_res.results.size() && "index out of range"); - GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received"); - batch_res.results[idx] = std::move(res); - } - return batch_res; - } - - void stop() { - ctx_server.queue_results.remove_waiting_task_ids(id_tasks); - if (has_next() && !cancelled) { - // if tasks is not finished yet, cancel them - cancelled = true; - std::vector cancel_tasks; - cancel_tasks.reserve(id_tasks.size()); - for (const auto& id_task : id_tasks) { - SRV_WRN("cancel task, id_task = %d\n", id_task); - server_task task(SERVER_TASK_TYPE_CANCEL); - task.id_target = id_task; - ctx_server.queue_results.remove_waiting_task_id(id_task); - cancel_tasks.push_back(std::move(task)); - } - // push to beginning of the queue, so it has highest priority - ctx_server.queue_tasks.post(std::move(cancel_tasks), true); - } - else { - SRV_DBG("%s", "all tasks already finished, no need to cancel\n"); - } - } -}; - -auto res_err = [](httplib::Response& res, json error_data) { - json final_response{ {"error", error_data} }; - res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON); - res.status = json_value(error_data, "code", 500); -}; - -auto res_ok = [](httplib::Response& res, const json& data) { - res.set_content(data.dump(), "application/json; charset=utf-8"); - res.status = 200; -}; - -std::function shutdown_handler; -std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; - -inline void signal_handler(int signal) { - if (is_terminating.test_and_set()) { - // in case it hangs, we can force terminate the server by hitting Ctrl+C twice - // this is for better developer experience, we can remove when the server is stable enough - fprintf(stderr, "Received second interrupt, terminating immediately.\n"); - exit(1); - } - - shutdown_handler(signal); -} - -int main(int argc, char ** argv) { -#if SERVER_VERBOSE != 1 - log_disable(); -#endif - // own arguments required by this example - gpt_params params; - - if (!gpt_params_parse(argc, argv, params)) { - gpt_params_print_usage(argc, argv, params); - return 1; - } - - // parse arguments from environment variables - gpt_params_parse_from_env(params); - - // TODO: not great to use extern vars - server_log_json = params.log_json; - server_verbose = params.verbosity > 0; - - - // struct that contains llama context and inference - server_context ctx_server; - - if (!params.system_prompt.empty()) { - ctx_server.system_prompt_set(params.system_prompt); - } - - if (params.model_alias == "unknown") { - params.model_alias = params.model; - } - - llama_backend_init(); - llama_numa_init(params.numa); - - LOG_INFO("build info", { - {"build", LLAMA_BUILD_NUMBER}, - {"commit", LLAMA_COMMIT} - }); - - LOG_INFO("system info", { - {"n_threads", params.n_threads}, - {"n_threads_batch", params.n_threads_batch}, - {"total_threads", std::thread::hardware_concurrency()}, - {"system_info", llama_print_system_info()}, - }); - - std::unique_ptr svr; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (params.ssl_file_key != "" && params.ssl_file_cert != "") { - LOG_INFO("Running with SSL", {{"key", params.ssl_file_key}, {"cert", params.ssl_file_cert}}); - svr.reset( - new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str()) - ); - } else { - LOG_INFO("Running without SSL", {}); - svr.reset(new httplib::Server()); - } -#else - svr.reset(new httplib::Server()); -#endif - - std::atomic state{SERVER_STATE_LOADING_MODEL}; - - svr->set_default_headers({{"Server", "ik_llama.cpp"}}); - - svr->set_logger(log_server_request); - - - - svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { - std::string message; - try { - std::rethrow_exception(std::move(ep)); - } catch (std::exception & e) { - message = e.what(); - } catch (...) { - message = "Unknown Exception"; - } - - json formatted_error = format_error_response(message, ERROR_TYPE_SERVER); - LOG_VERBOSE("Got exception", formatted_error); - res_err(res, formatted_error); - }); - - svr->set_error_handler([](const httplib::Request &, httplib::Response & res) { - if (res.status == 404) { - res_err(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND)); - } - // for other error codes, we skip processing here because it's already done by res_err() - }); - - // set timeouts and change hostname and port - svr->set_read_timeout (params.timeout_read); - svr->set_write_timeout(params.timeout_write); - - if (!svr->bind_to_port(params.hostname, params.port)) { - fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", params.hostname.c_str(), params.port); - return 1; - } - - std::unordered_map log_data; - - log_data["hostname"] = params.hostname; - log_data["port"] = std::to_string(params.port); - - if (params.api_keys.size() == 1) { - auto key = params.api_keys[0]; - log_data["api_key"] = "api_key: ****" + key.substr(std::max((int)(key.length() - 4), 0)); - } else if (params.api_keys.size() > 1) { - log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded"; - } - - // Necessary similarity of prompt for slot selection - ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; - ctx_server.cache_ram_n_min = params.cache_ram_n_min; - ctx_server.cache_ram_similarity = params.cache_ram_similarity; -#ifdef SQLITE3_MODERN_CPP_SUPPORT - auto db_handle = std::make_shared(params.sql_save_file); - bool sqlite_extension_loaded = false; - if (!params.sqlite_zstd_ext_file.empty()) { - auto* conn = db_handle->db.connection().get(); - sqlite3_enable_load_extension(conn, 1); - char* errmsg = nullptr; - const int rc = sqlite3_load_extension( - conn, - params.sqlite_zstd_ext_file.c_str(), - nullptr, - &errmsg - ); - if(rc != SQLITE_OK) { - const std::string err = errmsg ? errmsg : "Unknown extension error"; - sqlite3_free(errmsg); - LOG_WARNING("Failed to load extension", {{"err", err}}); - } - else { - sqlite_extension_loaded = true; - } - sqlite3_enable_load_extension(conn, 0); - } -#else - auto db_handle = false; -#endif - // load the model - if (!ctx_server.load_model(params)) { - state.store(SERVER_STATE_ERROR); - return 1; - } else { - ctx_server.init(); - state.store(SERVER_STATE_READY); - } - - LOG_INFO("model loaded", {}); - - const auto model_meta = ctx_server.model_meta(); - - // print sample chat example to make it clear which template is used - - LOG_INFO("chat template", { - {"chat_template", common_chat_templates_source(ctx_server.chat_templates.get())}, - }); - - LOG_INFO("chat template", { - {"chat_example", common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja, {}).c_str() - }, - {"built_in", params.chat_template.empty()}, - }); - // - // Middlewares - // - - auto middleware_validate_api_key = [¶ms](const httplib::Request & req, httplib::Response & res) { - static const std::unordered_set public_endpoints = { - "/health", - "/v1/health", - "/models", - "/v1/models", - "/api/tags" - }; - - // If API key is not set, skip validation - if (params.api_keys.empty()) { - return true; - } - - // If path is public or is static file, skip validation - if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") { - return true; - } - - // Check for API key in the header - auto auth_header = req.get_header_value("Authorization"); - - std::string prefix = "Bearer "; - if (auth_header.substr(0, prefix.size()) == prefix) { - std::string received_api_key = auth_header.substr(prefix.size()); - if (std::find(params.api_keys.begin(), params.api_keys.end(), received_api_key) != params.api_keys.end()) { - return true; // API key is valid - } - } - - auth_header = req.get_header_value("X-Api-Key"); - - if (std::find(params.api_keys.begin(), params.api_keys.end(), auth_header) != params.api_keys.end()) { - return true; // API key is valid - } - - // API key is invalid or not provided - res.status = 401; - res.set_content( - (json { - {"error", { - {"message", "Invalid API Key"}, - {"type", "authentication_error"}, - {"code", 401} - }} - }).dump(-1, ' ', false, json::error_handler_t::replace), - "application/json; charset=utf-8" - ); - LOG_WARNING("Unauthorized: Invalid API Key\n", {}); - return false; - }; - - auto middleware_server_state = [&state](const httplib::Request& req, httplib::Response& res) { - server_state current_state = state.load(); - if (current_state == SERVER_STATE_LOADING_MODEL) { - auto tmp = string_split(req.path, '.'); - if (req.path == "/" || tmp.back() == "html") { - res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); - res.status = 503; - } - else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") { - // allow the models endpoint to be accessed during loading - return true; - } - else { - res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); - } - return false; - } - return true; - }; - - // register server middlewares - svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request& req, httplib::Response& res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - // If this is OPTIONS request, skip validation because browsers don't include Authorization header - if (req.method == "OPTIONS") { - res.set_header("Access-Control-Allow-Credentials", "true"); - res.set_header("Access-Control-Allow-Methods", "GET, POST"); - res.set_header("Access-Control-Allow-Headers", "*"); - res.set_content("", "text/html"); // blank response, no data - return httplib::Server::HandlerResponse::Handled; // skip further processing - } - if (!middleware_server_state(req, res)) { - return httplib::Server::HandlerResponse::Handled; - } - if (!middleware_validate_api_key(req, res)) { - return httplib::Server::HandlerResponse::Handled; - } - return httplib::Server::HandlerResponse::Unhandled; - }); - - // - // Route handlers (or controllers) - // - - const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) { - server_state current_state = state.load(); - switch (current_state) { - case SERVER_STATE_READY: - { - // request slots data using task queue - server_task task; - task.id = ctx_server.queue_tasks.get_new_id(); - task.type = SERVER_TASK_TYPE_METRICS; - task.id_target = -1; - - ctx_server.queue_results.add_waiting_task_id(task.id); - ctx_server.queue_tasks.post(std::move(task)); - - // get the result - server_task_result result = ctx_server.queue_results.recv(task.id); - ctx_server.queue_results.remove_waiting_task_id(task.id); - - const int n_idle_slots = result.data.at("idle"); - const int n_processing_slots = result.data.at("processing"); - - json health = { - {"status", "ok"}, - {"slots_idle", n_idle_slots}, - {"slots_processing", n_processing_slots} - }; - - res.status = 200; // HTTP OK - if (params.endpoint_slots && req.has_param("include_slots")) { - health["slots"] = result.data.at("slots"); - } - - if (n_idle_slots == 0) { - health["status"] = "no slot available"; - if (req.has_param("fail_on_no_slot")) { - res.status = 503; // HTTP Service Unavailable - } - } - - res.set_content(health.dump(), "application/json"); - break; - } - case SERVER_STATE_LOADING_MODEL: - { - res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); - } break; - case SERVER_STATE_ERROR: - { - res_err(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER)); - } break; - } - }; - - const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) { - if (!params.endpoint_slots) { - res_err(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED)); - return; - } - - // request slots data using task queue - server_task task; - task.id = ctx_server.queue_tasks.get_new_id(); - task.id_multi = -1; - task.id_target = -1; - task.type = SERVER_TASK_TYPE_METRICS; - - ctx_server.queue_results.add_waiting_task_id(task.id); - ctx_server.queue_tasks.post(std::move(task)); - - // get the result - server_task_result result = ctx_server.queue_results.recv(task.id); - ctx_server.queue_results.remove_waiting_task_id(task.id); - - res.set_content(result.data.at("slots").dump(), "application/json"); - res.status = 200; // HTTP OK - }; - - const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { - if (!params.endpoint_metrics) { - res_err(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED)); - return; - } - - // request slots data using task queue - server_task task; - task.id = ctx_server.queue_tasks.get_new_id(); - task.id_multi = -1; - task.id_target = -1; - task.type = SERVER_TASK_TYPE_METRICS; - task.data.push_back({{"reset_bucket", true}}); - - ctx_server.queue_results.add_waiting_task_id(task.id); - ctx_server.queue_tasks.post(std::move(task)); - - // get the result - server_task_result result = ctx_server.queue_results.recv(task.id); - ctx_server.queue_results.remove_waiting_task_id(task.id); - - json data = result.data; - - const uint64_t n_prompt_tokens_processed = data.at("n_prompt_tokens_processed"); - const uint64_t t_prompt_processing = data.at("t_prompt_processing"); - - const uint64_t n_tokens_predicted = data.at("n_tokens_predicted"); - const uint64_t t_tokens_generation = data.at("t_tokens_generation"); - - const int32_t kv_cache_used_cells = data.at("kv_cache_used_cells"); - - // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names - json all_metrics_def = json { - {"counter", {{ - {"name", "prompt_tokens_total"}, - {"help", "Number of prompt tokens processed."}, - {"value", (uint64_t) data.at("n_prompt_tokens_processed_total")} - }, { - {"name", "prompt_seconds_total"}, - {"help", "Prompt process time"}, - {"value", (uint64_t) data.at("t_prompt_processing_total") / 1.e3} - }, { - {"name", "tokens_predicted_total"}, - {"help", "Number of generation tokens processed."}, - {"value", (uint64_t) data.at("n_tokens_predicted_total")} - }, { - {"name", "tokens_predicted_seconds_total"}, - {"help", "Predict process time"}, - {"value", (uint64_t) data.at("t_tokens_generation_total") / 1.e3} - }}}, - {"gauge", {{ - {"name", "prompt_tokens_seconds"}, - {"help", "Average prompt throughput in tokens/s."}, - {"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.} - },{ - {"name", "predicted_tokens_seconds"}, - {"help", "Average generation throughput in tokens/s."}, - {"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.} - },{ - {"name", "kv_cache_usage_ratio"}, - {"help", "KV-cache usage. 1 means 100 percent usage."}, - {"value", 1. * kv_cache_used_cells / params.n_ctx} - },{ - {"name", "kv_cache_tokens"}, - {"help", "KV-cache tokens."}, - {"value", (uint64_t) data.at("kv_cache_tokens_count")} - },{ - {"name", "requests_processing"}, - {"help", "Number of request processing."}, - {"value", (uint64_t) data.at("processing")} - },{ - {"name", "requests_deferred"}, - {"help", "Number of request deferred."}, - {"value", (uint64_t) data.at("deferred")} - }}} - }; - - std::stringstream prometheus; - - for (const auto & el : all_metrics_def.items()) { - const auto & type = el.key(); - const auto & metrics_def = el.value(); - - for (const auto & metric_def : metrics_def) { - const std::string name = metric_def.at("name"); - const std::string help = metric_def.at("help"); - - auto value = json_value(metric_def, "value", 0.); - prometheus << "# HELP llamacpp:" << name << " " << help << "\n" - << "# TYPE llamacpp:" << name << " " << type << "\n" - << "llamacpp:" << name << " " << value << "\n"; - } - } - - const int64_t t_start = data.at("t_start"); - res.set_header("Process-Start-Time-Unix", std::to_string(t_start)); - - res.set_content(prometheus.str(), "text/plain; version=0.0.4"); - res.status = 200; // HTTP OK - }; - - const auto handle_slots_save = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { - json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return; - } - std::string filepath = params.slot_save_path + filename; - - server_task task; - task.type = SERVER_TASK_TYPE_SLOT_SAVE; - task.data = { - { "id_slot", id_slot }, - { "filename", filename }, - { "filepath", filepath } - }; - - const int id_task = ctx_server.queue_tasks.post(std::move(task)); - ctx_server.queue_results.add_waiting_task_id(id_task); - - server_task_result result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); - - if (result.error) { - res_err(res, result.data); - } else { - res.set_content(result.data.dump(), "application/json"); - } - }; - - const auto handle_slots_restore = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { - json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return; - } - std::string filepath = params.slot_save_path + filename; - - server_task task; - task.type = SERVER_TASK_TYPE_SLOT_RESTORE; - task.data = { - { "id_slot", id_slot }, - { "filename", filename }, - { "filepath", filepath } - }; - - const int id_task = ctx_server.queue_tasks.post(std::move(task)); - ctx_server.queue_results.add_waiting_task_id(id_task); - - server_task_result result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); - - if (result.error) { - res_err(res, result.data); - } else { - res.set_content(result.data.dump(), "application/json"); - } - }; - - const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { - server_task task; - task.type = SERVER_TASK_TYPE_SLOT_ERASE; - task.data = { - { "id_slot", id_slot }, - }; - - const int id_task = ctx_server.queue_tasks.post(std::move(task)); - ctx_server.queue_results.add_waiting_task_id(id_task); - - server_task_result result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); - - if (result.error) { - res_err(res, result.data); - } else { - res.set_content(result.data.dump(), "application/json"); - } - }; - - const auto handle_slots_action = [&handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { - std::string id_slot_str = req.path_params.at("id_slot"); - int id_slot; - - try { - id_slot = std::stoi(id_slot_str); - } catch (const std::exception &) { - res_err(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); - return; - } - - std::string action = req.get_param_value("action"); - - if (action == "save") { - handle_slots_save(req, res, id_slot); - } else if (action == "restore") { - handle_slots_restore(req, res, id_slot); - } else if (action == "erase") { - handle_slots_erase(req, res, id_slot); - } else { - res_err(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); - } - }; - - const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) { - std::string template_key = "tokenizer.chat_template", curr_tmpl; - int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0); - if (tlen > 0) { - std::vector curr_tmpl_buf(tlen + 1, 0); - if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) { - curr_tmpl = std::string(curr_tmpl_buf.data(), tlen); - } - } - json data = { - { "system_prompt", ctx_server.system_prompt.c_str() }, - { "model_alias", ctx_server.params_base.model_alias }, - { "model_path", ctx_server.params_base.model}, - { "default_generation_settings", ctx_server.default_generation_settings_for_props }, - { "total_slots", ctx_server.params_base.n_parallel }, - { "model_name", get_model_name(ctx_server.params_base.model)}, - { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, - { "bos_token", common_token_to_piece(ctx_server.ctx, llama_token_bos(ctx_server.model), /* special= */ true)}, - { "eos_token", common_token_to_piece(ctx_server.ctx, llama_token_eos(ctx_server.model), /* special= */ true)}, - { "model_path", ctx_server.params_base.model }, - { "modalities", json { - {"vision", ctx_server.oai_parser_opt.allow_image}, - {"audio", ctx_server.oai_parser_opt.allow_audio}, - } }, - { "n_ctx", ctx_server.n_ctx } - - }; - - if (ctx_server.params_base.use_jinja) { - if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) { - data["chat_template_tool_use"] = tool_use_src; - } - } - res.set_content(data.dump(), "application/json; charset=utf-8"); - }; - - const auto handle_props_simple = [&ctx_server](const httplib::Request& req, httplib::Response& res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - int n_past = 0; - int slot_id = 0; - for (server_slot& slot : ctx_server.slots) { - if (slot.n_past > n_past) { - n_past = slot.n_past; - slot_id = slot.id; - } - } - json data = { - { "model_name", get_model_name(ctx_server.params_base.model)}, - { "model_path", ctx_server.params_base.model }, - { "modalities", json { - {"vision", ctx_server.oai_parser_opt.allow_image}, - {"audio", ctx_server.oai_parser_opt.allow_audio}, - } }, - { "n_ctx", ctx_server.n_ctx } - }; - res.set_content(data.dump(), "application/json; charset=utf-8"); - }; - - - // handle completion-like requests (completion, chat, infill) - // we can optionally provide a custom format for partial results and final results - const auto handle_completions_impl = [&ctx_server, ¶ms]( - server_task_type type, - json& data, - const std::vector& files, - const std::function& is_connection_closed, - httplib::Response& res, - oaicompat_type oaicompat) -> void { - GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); - - const auto completion_id = gen_chatcmplid(); - // need to store the reader as a pointer, so that it won't be destroyed when the handle returns - // use shared_ptr as it's shared between the chunked_content_provider() and on_complete() - const auto rd = std::make_shared(ctx_server); - - try { - std::vector tasks; - - const auto& prompt = data.at("prompt"); - - // process prompt - std::vector inputs; - - if (oaicompat && ctx_server.mctx != nullptr) { - // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below. - inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); - } - else { - // Everything else, including multimodal completions. - inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true); - } - tasks.reserve(inputs.size()); - for (size_t i = 0; i < inputs.size(); i++) { - server_task task = server_task(type); - - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - - task.tokens = std::move(inputs[i]); - task.data = data; - //task.params = server_task::params_from_json_cmpl( - // ctx_server.ctx, - // ctx_server.params, - // data); - task.id_slot = json_value(data, "id_slot", -1); - - // OAI-compat - task.params.oaicompat = oaicompat; - task.params.oaicompat_cmpl_id = completion_id; - task.params.oaicompat_model = get_model_name(ctx_server.params_base.model); - tasks.push_back(std::move(task)); - } - - rd->post_tasks(std::move(tasks)); - } - catch (const std::exception& e) { - res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); - return; - } - bool stream = json_value(data, "stream", false); - if (!stream) { - // non-stream, wait for the results - auto all_results = rd->wait_for_all(is_connection_closed); - if (all_results.is_terminated) { - llama_decode_stop(); // send a signal to stop decode process - return; // connection is closed - } - else if (all_results.error) { - res_err(res, all_results.error->to_json()); - return; - } - else { - json arr = json::array(); - for (auto& res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - arr.push_back(res->to_json()); - } - // if single request, return single object instead of array - res_ok(res, arr.size() == 1 ? arr[0] : arr); - } - } - else { - // in streaming mode, the first error must be treated as non-stream response - // this is to match the OAI API behavior - // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 - server_task_result_ptr first_result = rd->next(is_connection_closed); - if (first_result == nullptr) { - llama_decode_stop(); // send a signal to stop decode process - return; // connection is closed - } - else if (first_result->is_error()) { - res_err(res, first_result->to_json()); - return; - } - else { - GGML_ASSERT( - dynamic_cast(first_result.get()) != nullptr - || dynamic_cast(first_result.get()) != nullptr - ); - } - // next responses are streamed - json first_result_json = first_result->to_json(); - const auto chunked_content_provider = [first_result_json, rd, oaicompat](size_t, httplib::DataSink& sink) mutable -> bool { - const auto sse = [oaicompat, &sink](const json& res) { - if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) { - return server_sent_anthropic_event(sink, res); - } - else { - return server_sent_event(sink, res); - } - }; - // flush the first result as it's not an error - if (!first_result_json.empty()) { - if (!sse(first_result_json)) { - sink.done(); - return false; // sending failed, go to on_complete() - } - first_result_json.clear(); // mark as sent - } - - // receive subsequent results - auto result = rd->next([&sink] { return !sink.is_writable(); }); - if (result == nullptr) { - sink.done(); - return false; // connection is closed, go to on_complete() - } - - // send the results - json res_json = result->to_json(); - bool ok = false; - if (result->is_error()) { - ok = sse(json{ { "error", result->to_json() } }); - sink.done(); - return false; // go to on_complete() - } - else { - GGML_ASSERT( - dynamic_cast(result.get()) != nullptr - || dynamic_cast(result.get()) != nullptr - ); - ok = sse(res_json); - } - - if (!ok) { - sink.done(); - return false; // sending failed, go to on_complete() - } - - // check if there is more data - if (!rd->has_next()) { - if (oaicompat != OAICOMPAT_TYPE_ANTHROPIC && oaicompat != OAICOMPAT_TYPE_NONE) { - static const std::string ev_done = "data: [DONE]\n\n"; - sink.write(ev_done.data(), ev_done.size()); - } - sink.done(); - return false; // no more data, go to on_complete() - } - - // has next data, continue - return true; - }; - - auto on_complete = [rd](bool) { - rd->stop(); - }; - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } - }; - - const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - auto data = json::parse(req.body); - std::vector files; // dummy - handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - data, - files, - req.is_connection_closed, - res, - OAICOMPAT_TYPE_NONE); - }; - - const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request& req, httplib::Response& res) { - auto body = json::parse(req.body); - json data = oaicompat_chat_params_parse(body); - std::vector files; // dummy - handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - data, - files, - req.is_connection_closed, - res, - OAICOMPAT_TYPE_COMPLETION); - }; - - const auto handle_models = [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) { - json models = { - {"object", "list"}, - {"data", { - { - {"id", params.model_alias}, - {"object", "model"}, - {"created", std::time(0)}, - {"owned_by", "llamacpp"}, - {"meta", model_meta} - }, - }} - }; - - res.set_content(models.dump(), "application/json; charset=utf-8"); - }; - - - - const auto handle_chat_completions = [&ctx_server, ¶ms, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - auto body = json::parse(req.body); - std::vector files; - json data = oaicompat_chat_params_parse(ctx_server.model, body, ctx_server.oai_parser_opt, files); - handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - data, - files, - req.is_connection_closed, - res, - OAICOMPAT_TYPE_CHAT); - }; - - const auto handle_anthropic_messages = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - std::vector files; - json body = json::parse(req.body); - json body_parsed = anthropic_params_from_json( - ctx_server.model, - body, - ctx_server.oai_parser_opt, - files); - return handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - body_parsed, - files, - req.is_connection_closed, - res, - OAICOMPAT_TYPE_ANTHROPIC); - }; - - const auto handle_anthropic_count_tokens = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - std::vector files; - json body = json::parse(req.body); - - // Parse the Anthropic request (max_tokens is not required for count_tokens) - json body_parsed = anthropic_params_from_json( - ctx_server.model, - body, - ctx_server.oai_parser_opt, - files); - - json prompt = body_parsed.at("prompt"); - llama_tokens tokens = tokenize_mixed(llama_get_vocab(ctx_server.ctx), prompt, true, true); - - res_ok(res, {{"input_tokens", static_cast(tokens.size())}}); - return res; - }; - - // same with handle_chat_completions, but without inference part - const auto handle_apply_template = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) { - auto body = json::parse(req.body); - std::vector files; // dummy, unused - json data = oaicompat_chat_params_parse(ctx_server.model, body,ctx_server.oai_parser_opt, files); - res_ok(res, { { "prompt", std::move(data.at("prompt")) } }); - }; - - const auto handle_infill = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - json data = json::parse(req.body); - const int id_task = ctx_server.queue_tasks.get_new_id(); - server_tokens token; // dummy tokens - ctx_server.queue_results.add_waiting_task_id(id_task); - ctx_server.request_completion(id_task, -1, data, true, false, std::move(token)); - std::vector files; // dummy - handle_completions_impl( - SERVER_TASK_TYPE_INFILL, - data, - files, - req.is_connection_closed, - res, - OAICOMPAT_TYPE_NONE); // infill is not OAI compatible - }; - - const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { - const json body = json::parse(req.body); - - std::vector tokens; - if (body.count("content") != 0) { - const bool add_special = json_value(body, "add_special", false); - tokens = ctx_server.tokenize(body.at("content"), add_special); - } - const json data = format_tokenizer_response(tokens); - return res.set_content(data.dump(), "application/json; charset=utf-8"); - }; - - const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { - const json body = json::parse(req.body); - - std::string content; - if (body.count("tokens") != 0) { - const std::vector tokens = body.at("tokens"); - content = tokens_to_str(ctx_server.ctx, tokens); - } - - const json data = format_detokenized_response(content); - return res.set_content(data.dump(), "application/json; charset=utf-8"); - }; - - const auto handle_embeddings_impl = [&ctx_server](const httplib::Request& req, httplib::Response& res, oaicompat_type oaicompat) { - if (!ctx_server.params_base.embedding) { - res_err(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); - return; - } - - if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { - res_err(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); - return; - } - - const json body = json::parse(req.body); - - // for the shape of input/content, see tokenize_input_prompts() - json prompt; - if (body.count("input") != 0) { - prompt = body.at("input"); - } - else if (body.contains("content")) { - oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible - prompt = body.at("content"); - } - else { - res_err(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); - return; - } - - bool use_base64 = false; - if (body.count("encoding_format") != 0) { - const std::string& format = body.at("encoding_format"); - if (format == "base64") { - use_base64 = true; - } - else if (format != "float") { - res_err(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); - return; - } - } - auto vocab = llama_get_vocab(ctx_server.ctx); - auto tokenized_prompts = tokenize_input_prompts(vocab, ctx_server.mctx, prompt, true, true); - for (const auto& tokens : tokenized_prompts) { - // this check is necessary for models that do not add BOS token to the input - if (tokens.empty()) { - res_err(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); - return; - } - } - - int embd_normalize = 2; // default to Euclidean/L2 norm - if (body.count("embd_normalize") != 0) { - embd_normalize = body.at("embd_normalize"); - if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { - SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx)); - } - } - - // create and queue the task - json responses = json::array(); - server_response_reader rd(ctx_server); - { - std::vector tasks; - for (size_t i = 0; i < tokenized_prompts.size(); i++) { - server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); - - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.tokens = std::move(tokenized_prompts[i]); - - // OAI-compat - task.params.oaicompat = oaicompat; - task.params.embd_normalize = embd_normalize; - task.embedding = true; // probably not needed - tasks.push_back(std::move(task)); - } - rd.post_tasks(std::move(tasks)); - } - - // wait for the results - auto all_results = rd.wait_for_all(req.is_connection_closed); - - // collect results - if (all_results.is_terminated) { - llama_decode_stop(); - return; // connection is closed - } - else if (all_results.error) { - res_err(res, all_results.error->to_json()); - return; - } - else { - for (auto& res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - responses.push_back(res->to_json()); - } - } - - // write JSON response - json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING - ? format_embeddings_response_oaicompat(body, responses, use_base64) - : json(responses); - res_ok(res, root); - - }; - - const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request& req, httplib::Response& res) { - handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE); - }; - - const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request& req, httplib::Response& res) { - handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING); - }; - - - const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) { - json result = json::array(); - for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) { - auto & la = ctx_server.lora_adapters[i]; - result.push_back({ - {"id", i}, - {"path", la.path}, - {"scale", la.scale}, - }); - } - res.set_content(result.dump(), "application/json"); - res.status = 200; // HTTP OK - }; - - - const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) { - const std::vector body = json::parse(req.body); - int max_idx = ctx_server.lora_adapters.size(); - - // clear existing value - for (auto & la : ctx_server.lora_adapters) { - la.scale = 0.0f; - } - - // set value - for (auto entry : body) { - int id = entry.at("id"); - float scale = entry.at("scale"); - if (0 <= id && id < max_idx) { - ctx_server.lora_adapters[id].scale = scale; - } else { - throw std::runtime_error("invalid adapter id"); - } - } - - server_task task; - task.type = SERVER_TASK_TYPE_SET_LORA; - const int id_task = ctx_server.queue_tasks.post(std::move(task)); - ctx_server.queue_results.add_waiting_task_id(id_task); - - server_task_result result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); - - res.set_content(result.data.dump(), "application/json"); - res.status = 200; // HTTP OK - }; - - const auto list_saved_prompts = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) { - json response = json::array(); - - try { - for (const auto& entry : fs::directory_iterator(params.slot_save_path)) { - if (!entry.is_regular_file() || entry.file_size() < 12) { - continue; - } - - std::ifstream file(entry.path(), std::ios::binary); - if (!file) continue; - - uint32_t magic, version, n_token_count; - file.read(reinterpret_cast(&magic), sizeof(magic)); - file.read(reinterpret_cast(&version), sizeof(version)); - file.read(reinterpret_cast(&n_token_count), sizeof(n_token_count)); - - if (magic != LLAMA_STATE_SEQ_MAGIC || - version != LLAMA_STATE_SEQ_VERSION || - entry.file_size() < (12 + (n_token_count * sizeof(llama_token)))) { - continue; - } - - std::vector tokens(n_token_count); - file.read(reinterpret_cast(tokens.data()), tokens.size() * sizeof(llama_token)); - - //C++17 is not modern enough to have a nice and portable way to get the mtime of a file - //so the following seems to be needed - auto ftime = fs::last_write_time(entry.path()); - auto system_time = std::chrono::time_point_cast( - ftime - fs::file_time_type::clock::now() + std::chrono::system_clock::now() - ); - std::time_t c_time = std::chrono::system_clock::to_time_t(system_time); - std::tm tm_struct; - #if defined(_WIN32) - localtime_s(&tm_struct, &c_time); - #else - localtime_r(&c_time, &tm_struct); - #endif - std::ostringstream oss; - oss << std::put_time(&tm_struct, "%Y-%m-%d %H:%M:%S"); - auto str_time = oss.str(); - - - response.push_back({ - {"filename", entry.path().filename().string()}, - {"filesize", entry.file_size()}, - {"mtime", str_time}, - {"token_count", n_token_count}, - {"prompt", tokens_to_str(ctx_server.ctx, tokens)} - }); - } - } catch (const std::exception& e) { - res.status = 500; - response = {{"error", e.what()}}; - } - res.set_content(response.dump(), "application/json; charset=utf-8"); - }; - - const auto list_slot_prompts = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) { - json response = json::array(); - for (server_slot & slot : ctx_server.slots) { - response.push_back({ - {"slot_id", slot.id}, - {"token_count", slot.cache_tokens.size()}, - {"prompt", slot.cache_tokens.detokenize(ctx_server.ctx, true) } - }); - } - res.set_content(response.dump(), "application/json; charset=utf-8"); - }; - - - const auto delete_saved_prompt = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res)-> void { - json response; - namespace fs = std::filesystem; - - try { - const json body = json::parse(req.body); - const std::string filename_str = body.at("filename"); - - // prevent directory traversal attacks - if (filename_str.find("..") != std::string::npos || filename_str.find('/') != std::string::npos || filename_str.find('\\') != std::string::npos) { - res.status = 400; - response = {{"error", "Invalid filename format."}}; - res.set_content(response.dump(), "application/json; charset=utf-8"); - return; - } - - const fs::path file_to_delete = fs::path(params.slot_save_path) / fs::path(filename_str); - - if (!fs::exists(file_to_delete) || !fs::is_regular_file(file_to_delete)) { - res.status = 404; - response = {{"error", "File not found."}}; - res.set_content(response.dump(), "application/json; charset=utf-8"); - return; - } - - if (fs::remove(file_to_delete)) { - response = { - {"status", "deleted"}, - {"filename", filename_str} - }; - } else { - res.status = 500; - response = {{"error", "Failed to delete the file."}}; - } - } catch (const json::parse_error& e) { - res.status = 400; - response = {{"error", "Invalid JSON request body."}}; - } catch (const json::out_of_range& e) { - res.status = 400; - response = {{"error", "Missing 'filename' key in request body."}}; - } catch (const std::exception& e) { - res.status = 500; - response = {{"error", e.what()}}; - } - res.set_content(response.dump(), "application/json; charset=utf-8"); - }; - - const auto rename_saved_prompt = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res)-> void { - json response; - namespace fs = std::filesystem; - - try { - const json body = json::parse(req.body); - const std::string old_filename_str = body.at("old_filename"); - const std::string new_filename_str = body.at("new_filename"); - - if (old_filename_str.find("..") != std::string::npos || old_filename_str.find_first_of("/\\") != std::string::npos || - new_filename_str.find("..") != std::string::npos || new_filename_str.find_first_of("/\\") != std::string::npos) { - res.status = 400; - response = {{"error", "Invalid filename format."}}; - res.set_content(response.dump(), "application/json; charset=utf-8"); - return; - } - - const fs::path old_path = fs::path(params.slot_save_path) / old_filename_str; - const fs::path new_path = fs::path(params.slot_save_path) / new_filename_str; - - if (!fs::exists(old_path) || !fs::is_regular_file(old_path)) { - res.status = 404; - response = {{"error", "Source file not found."}}; - res.set_content(response.dump(), "application/json; charset=utf-8"); - return; - } - - if (fs::exists(new_path)) { - res.status = 409; - response = {{"error", "Destination filename already exists."}}; - res.set_content(response.dump(), "application/json; charset=utf-8"); - return; - } - - std::error_code ec; - fs::rename(old_path, new_path, ec); - - if (ec) { - res.status = 500; - response = {{"error", "Failed to rename file: " + ec.message()}}; - } else { - response = { - {"status", "renamed"}, - {"old_filename", old_filename_str}, - {"new_filename", new_filename_str} - }; - } - - } catch (const json::parse_error& e) { - res.status = 400; - response = {{"error", "Invalid JSON request body."}}; - } catch (const json::out_of_range& e) { - res.status = 400; - response = {{"error", "Missing 'old_filename' or 'new_filename' in request body."}}; - } catch (const std::exception& e) { - res.status = 500; - response = {{"error", e.what()}}; - } - - res.set_content(response.dump(), "application/json; charset=utf-8"); - }; - - auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) { - return [content, len, mime_type](const httplib::Request &, httplib::Response & res) { - res.set_content(reinterpret_cast(content), len, mime_type); - return false; - }; - }; -#ifdef SQLITE3_MODERN_CPP_SUPPORT - const auto handle_version = [¶ms, sqlite_extension_loaded](const httplib::Request&, httplib::Response& res) { - res.set_content( - json{{"version", 4}, - {"features", {{"sql", !params.sql_save_file.empty()}, {"zstd_compression", sqlite_extension_loaded}}}}.dump(), - "application/json" - ); - }; -#else - const auto handle_version = [](const httplib::Request&, httplib::Response& res)-> void { - res.set_content( - json{{"version", 4}, - {"features", {{"sql", false}, {"zstd_compression", false}}}}.dump(), - "application/json" - ); - }; -#endif - -#ifdef SQLITE3_MODERN_CPP_SUPPORT - auto db_handler = [db_handle](auto func) { - return [func, db_handle](const httplib::Request& req, httplib::Response& res) { - res.set_header("Access-Control-Allow-Origin", "*"); - try { - const json body = !req.body.empty() ? json::parse(req.body) : json::object(); - func(*db_handle, body, req, res); - } catch(const std::exception& e) { - res.status = 500; - res.set_content( - json{{"ok", false}, {"message", e.what()}}.dump(), - "application/json" - ); - } - }; - }; -#else - auto db_handler = [db_handle](auto func) { - return [func, db_handle](const httplib::Request& req, httplib::Response& res) { - res.set_header("Access-Control-Allow-Origin", "*"); - res.status = 500; - res.set_content( - json{{"ok", false}, {"message", "Sqlite3 support was not enabled. Recompile with '-DLLAMA_SERVER_SQLITE3=ON'"}}.dump(), - "application/json" - ); - }; - }; -#endif - - const auto normalize_store_name = [](const std::string& storeName) { - if(storeName.empty()) return std::string("sessions"); - - std::string normalized; - normalized.reserve(storeName.size()); - - for(char c : storeName) { - if(std::isalpha(static_cast(c))) { - normalized.push_back(std::tolower(static_cast(c))); - } - } - - return normalized.empty() ? "sessions" : normalized; - }; - - const auto get_key_string = [](const json& j) { - return j.is_string() ? j.get() : j.dump(); - }; - - - const auto handle_load = db_handler([normalize_store_name, get_key_string](auto& db, const json& body, auto&, auto& res) { - std::string data; - const std::string store = normalize_store_name(body["storeName"]); - db.db << "SELECT data FROM " + store + " WHERE key = ?" << get_key_string(body["key"]) >> data; - if(data.empty()) { - res.status = 404; - res.set_content(json{{"ok", false}, {"message", "Key not found"}}.dump(), "application/json"); - } else { - json response{{"ok", true}}; - response["result"] = (store == "names") ? json(data) : json::parse(data); - res.set_content(response.dump(), "application/json"); - } - }); - - const auto handle_save = db_handler([normalize_store_name, get_key_string](auto& db, const json& body, auto&, auto& res) { - const std::string store = normalize_store_name(body["storeName"]); - const std::string data = (store == "names") ? body["data"].get() : body["data"].dump(); - db.db << "INSERT OR REPLACE INTO " + store + " (key, data) VALUES (?, ?)" << get_key_string(body["key"]) << data; - res.set_content(json{{"ok", true}, {"result", "Data saved successfully"}}.dump(), "application/json"); - }); - - const auto handle_rename = db_handler([get_key_string](auto& db, const json& body, auto&, auto& res) { - db.db << "UPDATE names SET data = ? WHERE key = ?" - << body["newName"].get() - << get_key_string(body["key"]); - res.set_content(json{{"ok", true}, {"result", "Session renamed successfully"}}.dump(), "application/json"); - }); - - const auto handle_all = db_handler([normalize_store_name](auto& db, const json& body, auto&, auto& res) { - json result = json::object(); - db.db << "SELECT key, data FROM " + normalize_store_name(body["storeName"]) >> - [&](const std::string& key, const std::string& data) { - result[key] = json::parse(data); - }; - res.set_content(json{{"ok", true}, {"result", result}}.dump(), "application/json"); - }); - - const auto handle_sessions = db_handler([](auto& db, const json& body, auto&, auto& res) { - json result = json::object(); - db.db << "SELECT key, data FROM names" >> [&](const std::string& key, const std::string& data) { - result[key] = data; - }; - res.set_content(json{{"ok", true}, {"result", result}}.dump(), "application/json"); - }); - - const auto handle_delete = db_handler([normalize_store_name, get_key_string](auto& db, const json& body, auto&, auto& res) { - db.db << "DELETE FROM " + normalize_store_name(body["storeName"]) + " WHERE key = ?" - << get_key_string(body["key"]); - res.set_content(json{{"ok", true}, {"result", "Session deleted successfully"}}.dump(), "application/json"); - }); - - const auto handle_vacuum = db_handler([](auto& db, const json& body, auto&, auto& res) { - json result = json::object(); - db.db << "VACUUM"; - res.set_content(json{"ok", true}.dump(), "application/json"); - }); - - const auto handle_zstd_get_configs = db_handler([](auto& db, const json& body, auto&, auto& res) { - json result = json::object(); - db.db << "SELECT id, config FROM _zstd_configs" >> [&](const std::string id, const std::string& config) { - result[id] = config; - }; - res.set_content(json{{"ok", true}, {"configs", result}}.dump(), "application/json"); - }); - - const auto handle_zstd_maintenance = db_handler([](auto& db, const json& body, auto&, auto& res) { - std::string data; - if (body["duration"].is_null()) { - db.db << "select zstd_incremental_maintenance(?, ?)" << nullptr << body["db_load"].get() >> data; - } - else { - db.db << "select zstd_incremental_maintenance(?, ?)" << body["duration"].get() << body["db_load"].get() >> data; - } - json response{{"ok", true}}; - response["result"] = json::parse(data); - res.set_content(response.dump(), "application/json"); - }); - - const auto handle_zstd_enable = db_handler([](auto& db, const json& body, auto&, auto& res) { - db.db << "select zstd_enable_transparent('{\"table\": \"" + body["table"].get() + "\",\"column\": \"" + body["column"].get() + "\", \"compression_level\": " + std::to_string(body["compression_level"].get()) + ", \"dict_chooser\": \"''a''\", \"train_dict_samples_ratio\": " + std::to_string(body["train_dict_samples_ratio"].get()) + "}')"; - res.set_content(json{"ok", true}.dump(), "application/json"); - }); - - const auto handle_zstd_config_update = db_handler([](auto& db, const json& body, auto&, auto& res) { - std::string patch_json = "{\"compression_level\": " + std::to_string(body["compression_level"].get()) + ", \"train_dict_samples_ratio\": " + std::to_string(body["train_dict_samples_ratio"].get()) + "}"; - db.db << "update _zstd_configs set config = json_patch(config, '" + patch_json + "')"; - res.set_content(json{{"ok", true}}.dump(), "application/json"); - }); - - // - // Router - // - if (params.webui == COMMON_WEBUI_NONE) { - LLAMA_LOG_INFO("Web UI is disabled\n"); - } - else { - // register static assets routes - if (!params.public_path.empty()) { - // Set the base directory for serving static files - svr->set_base_dir(params.public_path); - } - - { - // register static assets routes - if (!params.public_path.empty()) { - // Set the base directory for serving static files - bool is_found = svr->set_mount_point("/", params.public_path); - if (!is_found) { - GGML_ABORT("%s: static assets path not found: %s\n", __func__, params.public_path.c_str()); - return 1; - } - } - else { - - // using embedded static index.html - svr->Get("/", [params](const httplib::Request& req, httplib::Response& res) { - if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) { - res.set_content("Error: gzip is not supported by this browser", "text/plain"); - } - else { - res.set_header("Content-Encoding", "gzip"); - // COEP and COOP headers, required by pyodide (python interpreter) - res.set_header("Cross-Origin-Embedder-Policy", "require-corp"); - res.set_header("Cross-Origin-Opener-Policy", "same-origin"); - if (params.webui == COMMON_WEBUI_AUTO) { - res.set_content(reinterpret_cast(index_html_gz), index_html_gz_len, "text/html; charset=utf-8"); - } - else if (params.webui == COMMON_WEBUI_LLAMACPP) { - res.set_content(reinterpret_cast(index_llamacpp_html_gz), index_llamacpp_html_gz_len, "text/html; charset=utf-8"); - } - else { - res.set_content(reinterpret_cast(index_html_gz), index_html_gz_len, "text/html; charset=utf-8"); - } - } - return false; - }); - } - } - } - // register API routes - svr->Get ("/health", handle_health); - svr->Get ("/metrics", handle_metrics); - svr->Get ("/props", handle_props); - svr->Get("/v1/props", handle_props_simple); - svr->Get ("/v1/models", handle_models); - svr->Post("/completion", handle_completions); // legacy - svr->Post("/completions", handle_completions); // legacy - svr->Post("/v1/completions", handle_completions_oai); - svr->Post("/chat/completions", handle_chat_completions); - svr->Post("/v1/chat/completions", handle_chat_completions); - svr->Post("/v1/messages", handle_anthropic_messages); - svr->Post("/v1/messages/count_tokens", handle_anthropic_count_tokens); - svr->Post("/infill", handle_infill); - svr->Post("/embedding", handle_embeddings); // legacy - svr->Post("/embeddings", handle_embeddings); - svr->Post("/v1/embeddings", handle_embeddings_oai); - svr->Post("/tokenize", handle_tokenize); - svr->Post("/detokenize", handle_detokenize); - svr->Post("/apply-template", handle_apply_template); - // LoRA adapters hotswap - svr->Get ("/lora-adapters", handle_lora_adapters_list); - svr->Post("/lora-adapters", handle_lora_adapters_apply); - // Save & load slots - svr->Get ("/slots", handle_slots); - svr->Get ("/slots/list", list_slot_prompts); - if (!params.slot_save_path.empty()) { - // these endpoints rely on slot_save_path existing - svr->Post("/slots/:id_slot", handle_slots_action); - svr->Get ("/list", list_saved_prompts); - svr->Post("/delete_prompt", delete_saved_prompt); - svr->Post("/rename_prompt", rename_saved_prompt); - - } - - svr->Get ("/version", handle_version); - if (!params.sql_save_file.empty()) { - // these endpoints rely on sql_save_file existing - svr->Post("/load", handle_load); - svr->Post("/save", handle_save); - svr->Post("/rename", handle_rename); - svr->Post("/all", handle_all); - svr->Post("/sessions", handle_sessions); - svr->Get ("/sessions", handle_sessions); - svr->Post("/delete", handle_delete); - //VACUUM is there for the extension but does not require the extension - svr->Get ("/vacuum", handle_vacuum); -#ifdef SQLITE3_MODERN_CPP_SUPPORT - if (sqlite_extension_loaded) { - svr->Get ("/zstd_get_configs", handle_zstd_get_configs); - svr->Post("/zstd_incremental_maintenance", handle_zstd_maintenance); - svr->Post("/zstd_enable_transparent", handle_zstd_enable); - svr->Post("/zstd_update_transparent", handle_zstd_config_update); - } -#endif - } - // - // Start the server - // - if (params.n_threads_http < 1) { - // +2 threads for monitoring endpoints - params.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1); - } - log_data["n_threads_http"] = std::to_string(params.n_threads_http); - svr->new_task_queue = [¶ms] { return new httplib::ThreadPool(params.n_threads_http); }; - - LOG_INFO("HTTP server listening", log_data); - - // run the HTTP server in a thread - see comment below - std::thread t([&]() { - if (!svr->listen_after_bind()) { - state.store(SERVER_STATE_ERROR); - return 1; - } - - return 0; - }); - - ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) { - ctx_server.process_single_task(std::move(task)); - }); - ctx_server.queue_tasks.on_finish_multitask(std::bind( - &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1)); - ctx_server.queue_tasks.on_update_slots(std::bind( - &server_context::update_slots, &ctx_server)); - ctx_server.queue_results.on_multitask_update(std::bind( - &server_queue::update_multitask, - &ctx_server.queue_tasks, - std::placeholders::_1, - std::placeholders::_2, - std::placeholders::_3 - )); - - shutdown_handler = [&](int) { - ctx_server.queue_tasks.terminate(); - }; - -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) - struct sigaction sigint_action; - sigint_action.sa_handler = signal_handler; - sigemptyset (&sigint_action.sa_mask); - sigint_action.sa_flags = 0; - sigaction(SIGINT, &sigint_action, NULL); - sigaction(SIGTERM, &sigint_action, NULL); -#elif defined (_WIN32) - auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { - return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; - }; - SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); -#endif - - ctx_server.queue_tasks.start_loop(); - - svr->stop(); - t.join(); - - llama_backend_free(); - - return 0; -} +#pragma warning(disable : 4996) +#include "server-context.h" +#include "server-common.h" +#include "chat.h" + +#include "common.h" +#include "speculative.h" +#include "mtmd.h" +#include "sampling.h" +#include "llama.h" +#include "llama-vocab.h" +#include +#include +#include + +// mime type for sending response +#define MIMETYPE_JSON "application/json; charset=utf-8" + + +#ifndef NDEBUG +// crash the server in debug mode, otherwise send an http 500 error +#define CPPHTTPLIB_NO_EXCEPTIONS 1 +#endif + +#include +#include "index.html.gz.hpp" +#include "index_llamacpp.html.gz.hpp" +#include "loading.html.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef SQLITE3_MODERN_CPP_SUPPORT +#include + +struct DatabaseHandle { + sqlite::database db; + + DatabaseHandle(const std::string& path) : db(path) { + db << "CREATE TABLE IF NOT EXISTS sessions (key TEXT PRIMARY KEY, data TEXT)"; + db << "CREATE TABLE IF NOT EXISTS templates (key TEXT PRIMARY KEY, data TEXT)"; + db << "CREATE TABLE IF NOT EXISTS names (key TEXT PRIMARY KEY, data TEXT)"; + } +}; +#endif + +using json = nlohmann::ordered_json; +namespace fs = std::filesystem; +constexpr int HTTP_POLLING_SECONDS = 1; + +bool server_verbose = false; +bool server_log_json = true; + + +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded + SERVER_STATE_ERROR // An error occurred, load_model failed +}; + + +static inline std::string stop_type_to_str(stop_type type) { + switch (type) { + case STOP_TYPE_EOS: return "eos"; + case STOP_TYPE_WORD: return "word"; + case STOP_TYPE_LIMIT: return "limit"; + default: return "none"; + } +} + + +inline std::string get_model_name(std::string path) +{ + std::string filename = path.substr(path.find_last_of("/\\") + 1); + return filename; +}; + + +static json format_final_response_oaicompat(const json& request, json result, const std::string& completion_id, bool streaming = false) { + bool stopped_word = result.count("stopped_word") != 0; + bool stopped_eos = json_value(result, "stopped_eos", false); + int num_tokens_predicted = json_value(result, "tokens_predicted", 0); + int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); + std::string content = json_value(result, "content", std::string("")); + + std::string finish_reason = "length"; + if (stopped_word || stopped_eos) { + finish_reason = "stop"; + } + + json choices = + streaming ? json::array({ json{{"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()}} }) + : json::array({ json{{"finish_reason", finish_reason}, + {"index", 0}, + {"message", json{{"content", content}, + {"role", "assistant"}}}} }); + + std::time_t t = std::time(0); + + json res = json{ + {"choices", choices}, + {"created", t}, + {"model", + json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", streaming ? "chat.completion.chunk" : "chat.completion"}, + {"usage", json { + {"completion_tokens", num_tokens_predicted}, + {"prompt_tokens", num_prompt_tokens}, + {"total_tokens", num_tokens_predicted + num_prompt_tokens} + }}, + {"id", completion_id} + }; + + if (server_verbose) { + res["__verbose"] = result; + } + + if (result.contains("completion_probabilities")) { + res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array()); + } + + return res; +} + +// return value is vector as there is one case where we might need to generate two responses +static std::vector format_partial_response_oaicompat(server_task_result task_result, const std::string& completion_id) { + json result = task_result.data; + std::cout << result.dump(4) << std::endl; + if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) { + return std::vector({ result }); + } + + bool first = json_value(result, "oaicompat_token_ctr", 0) == 0; + std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); + + bool stopped_word = json_value(result, "stopped_word", false); + bool stopped_eos = json_value(result, "stopped_eos", false); + bool stopped_limit = json_value(result, "stopped_limit", false); + std::string content = json_value(result, "content", std::string("")); + + std::string finish_reason; + if (stopped_word || stopped_eos) { + finish_reason = "stop"; + } + if (stopped_limit) { + finish_reason = "length"; + } + + std::time_t t = std::time(0); + + json choices; + + if (!finish_reason.empty()) { + choices = json::array({ json{{"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()}} }); + } + else { + if (first) { + if (content.empty()) { + choices = json::array({ json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{{"role", "assistant"}}}} }); + } + else { + // We have to send this as two updates to conform to openai behavior + json initial_ret = json{ {"choices", json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"role", "assistant"} + }}}})}, + {"created", t}, + {"id", completion_id}, + {"model", modelname}, + {"object", "chat.completion.chunk"} }; + + json second_ret = json{ + {"choices", json::array({json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"content", content}}} + }})}, + {"created", t}, + {"id", completion_id}, + {"model", modelname}, + {"object", "chat.completion.chunk"} }; + + return std::vector({ initial_ret, second_ret }); + } + } + else { + // Some idiosyncrasy in task processing logic makes several trailing calls + // with empty content, we ignore these at the calee site. + if (content.empty()) { + return std::vector({ json::object() }); + } + + choices = json::array({ json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", + json{ + {"content", content}, + }}, + } }); + } + } + + json ret = json{ + {"choices", choices}, + {"created", t}, + {"id", completion_id}, + {"model", modelname}, + {"object", "chat.completion.chunk"} + }; + + if (task_result.timings.prompt_n != -1) { + ret.push_back({ "timings", task_result.timings.to_json() }); + } + + // + if (!finish_reason.empty()) { + int num_tokens_predicted = json_value(result, "tokens_predicted", 0); + int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); + ret.push_back({ "usage", json { + {"completion_tokens", num_tokens_predicted}, + {"prompt_tokens", num_prompt_tokens}, + {"total_tokens", num_tokens_predicted + num_prompt_tokens} + } }); + } + + return std::vector({ ret }); +} + + +static json format_embeddings_response_oaicompat(const json& request, const json& embeddings, bool use_base64 = false) { + json data = json::array(); + int32_t n_tokens = 0; + int i = 0; + for (const auto& elem : embeddings) { + json embedding_obj; + + if (use_base64) { + const auto& vec = json_value(elem, "embedding", json::array()).get>(); + const char* data_ptr = reinterpret_cast(vec.data()); + size_t data_size = vec.size() * sizeof(float); + embedding_obj = { + {"embedding", base64::encode(data_ptr, data_size)}, + {"index", i++}, + {"object", "embedding"}, + {"encoding_format", "base64"} + }; + } + else { + embedding_obj = { + {"embedding", json_value(elem, "embedding", json::array())}, + {"index", i++}, + {"object", "embedding"} + }; + } + data.push_back(embedding_obj); + n_tokens += json_value(elem, "tokens_evaluated", 0); + } + json res = json{ + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json { + {"prompt_tokens", n_tokens}, + {"total_tokens", n_tokens} + }}, + {"data", data} + }; + + return res; +} + +static void log_server_request(const httplib::Request & req, const httplib::Response & res) { + // skip GH copilot requests when using default port + if (req.path == "/v1/health" || req.path == "/v1/completions") { + return; + } + + LOG_INFO("request", { + {"remote_addr", req.remote_addr}, + {"remote_port", req.remote_port}, + {"status", res.status}, + {"method", req.method}, + {"path", req.path}, + {"params", req.params}, + }); + + LOG_VERBOSE("request", { + {"request", req.body}, + {"response", res.body}, + }); +} + +// generator-like API for server responses, support pooling connection state and aggregating results +struct server_response_reader { + std::unordered_set id_tasks; + server_context& ctx_server; + size_t received_count = 0; + bool cancelled = false; + + server_response_reader(server_context& ctx_server) : ctx_server(ctx_server) {} + ~server_response_reader() { + stop(); + } + + void post_tasks(std::vector&& tasks) { + id_tasks = server_task::get_list_id(tasks); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(std::move(tasks)); + } + + bool has_next() { + return !cancelled && received_count < id_tasks.size(); + } + + // return nullptr if should_stop() is true before receiving a result + // note: if one error is received, it will stop further processing and return error result + server_task_result_ptr next(const std::function& should_stop) { + while (true) { + server_task_result_ptr result = ctx_server.queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); + if (result == nullptr) { + // timeout, check stop condition + if (should_stop()) { + SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n"); + return nullptr; + } + } + else { + if (result->is_error()) { + stop(); // cancel remaining tasks + SRV_DBG("%s", "received error result, stopping further processing\n"); + return result; + } + if (result->is_stop()) { + received_count++; + } + return result; + } + } + + // should not reach here + } + + struct batch_response { + bool is_terminated = false; // if true, indicates that processing was stopped before all results were received + std::vector results; + server_task_result_ptr error; // nullptr if no error + }; + + batch_response wait_for_all(const std::function& should_stop) { + batch_response batch_res; + batch_res.results.resize(id_tasks.size()); + while (has_next()) { + auto res = next(should_stop); + if (res == nullptr) { + batch_res.is_terminated = true; + return batch_res; + } + if (res->error) { + batch_res.error = std::move(res); + return batch_res; + } + const size_t idx = res->get_index(); + GGML_ASSERT(idx < batch_res.results.size() && "index out of range"); + GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received"); + batch_res.results[idx] = std::move(res); + } + return batch_res; + } + + void stop() { + ctx_server.queue_results.remove_waiting_task_ids(id_tasks); + if (has_next() && !cancelled) { + // if tasks is not finished yet, cancel them + cancelled = true; + std::vector cancel_tasks; + cancel_tasks.reserve(id_tasks.size()); + for (const auto& id_task : id_tasks) { + SRV_WRN("cancel task, id_task = %d\n", id_task); + server_task task(SERVER_TASK_TYPE_CANCEL); + task.id_target = id_task; + ctx_server.queue_results.remove_waiting_task_id(id_task); + cancel_tasks.push_back(std::move(task)); + } + // push to beginning of the queue, so it has highest priority + ctx_server.queue_tasks.post(std::move(cancel_tasks), true); + } + else { + SRV_DBG("%s", "all tasks already finished, no need to cancel\n"); + } + } +}; + +auto res_err = [](httplib::Response& res, json error_data) { + json final_response{ {"error", error_data} }; + res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON); + res.status = json_value(error_data, "code", 500); +}; + +auto res_ok = [](httplib::Response& res, const json& data) { + res.set_content(data.dump(), "application/json; charset=utf-8"); + res.status = 200; +}; + +std::function shutdown_handler; +std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; + +inline void signal_handler(int signal) { + if (is_terminating.test_and_set()) { + // in case it hangs, we can force terminate the server by hitting Ctrl+C twice + // this is for better developer experience, we can remove when the server is stable enough + fprintf(stderr, "Received second interrupt, terminating immediately.\n"); + exit(1); + } + + shutdown_handler(signal); +} + +int main(int argc, char ** argv) { +#if SERVER_VERBOSE != 1 + log_disable(); +#endif + // own arguments required by this example + gpt_params params; + + if (!gpt_params_parse(argc, argv, params)) { + gpt_params_print_usage(argc, argv, params); + return 1; + } + + // parse arguments from environment variables + gpt_params_parse_from_env(params); + + // TODO: not great to use extern vars + server_log_json = params.log_json; + server_verbose = params.verbosity > 0; + + + // struct that contains llama context and inference + server_context ctx_server; + + if (!params.system_prompt.empty()) { + ctx_server.system_prompt_set(params.system_prompt); + } + + if (params.model_alias == "unknown") { + params.model_alias = params.model; + } + + llama_backend_init(); + llama_numa_init(params.numa); + + LOG_INFO("build info", { + {"build", LLAMA_BUILD_NUMBER}, + {"commit", LLAMA_COMMIT} + }); + + LOG_INFO("system info", { + {"n_threads", params.n_threads}, + {"n_threads_batch", params.n_threads_batch}, + {"total_threads", std::thread::hardware_concurrency()}, + {"system_info", llama_print_system_info()}, + }); + + std::unique_ptr svr; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (params.ssl_file_key != "" && params.ssl_file_cert != "") { + LOG_INFO("Running with SSL", {{"key", params.ssl_file_key}, {"cert", params.ssl_file_cert}}); + svr.reset( + new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str()) + ); + } else { + LOG_INFO("Running without SSL", {}); + svr.reset(new httplib::Server()); + } +#else + svr.reset(new httplib::Server()); +#endif + + std::atomic state{SERVER_STATE_LOADING_MODEL}; + + svr->set_default_headers({{"Server", "ik_llama.cpp"}}); + + svr->set_logger(log_server_request); + + + + svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { + std::string message; + try { + std::rethrow_exception(std::move(ep)); + } catch (std::exception & e) { + message = e.what(); + } catch (...) { + message = "Unknown Exception"; + } + + json formatted_error = format_error_response(message, ERROR_TYPE_SERVER); + LOG_VERBOSE("Got exception", formatted_error); + res_err(res, formatted_error); + }); + + svr->set_error_handler([](const httplib::Request &, httplib::Response & res) { + if (res.status == 404) { + res_err(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND)); + } + // for other error codes, we skip processing here because it's already done by res_err() + }); + + // set timeouts and change hostname and port + svr->set_read_timeout (params.timeout_read); + svr->set_write_timeout(params.timeout_write); + + if (!svr->bind_to_port(params.hostname, params.port)) { + fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", params.hostname.c_str(), params.port); + return 1; + } + + std::unordered_map log_data; + + log_data["hostname"] = params.hostname; + log_data["port"] = std::to_string(params.port); + + if (params.api_keys.size() == 1) { + auto key = params.api_keys[0]; + log_data["api_key"] = "api_key: ****" + key.substr(std::max((int)(key.length() - 4), 0)); + } else if (params.api_keys.size() > 1) { + log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded"; + } + + // Necessary similarity of prompt for slot selection + ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; + ctx_server.cache_ram_n_min = params.cache_ram_n_min; + ctx_server.cache_ram_similarity = params.cache_ram_similarity; +#ifdef SQLITE3_MODERN_CPP_SUPPORT + auto db_handle = std::make_shared(params.sql_save_file); + bool sqlite_extension_loaded = false; + if (!params.sqlite_zstd_ext_file.empty()) { + auto* conn = db_handle->db.connection().get(); + sqlite3_enable_load_extension(conn, 1); + char* errmsg = nullptr; + const int rc = sqlite3_load_extension( + conn, + params.sqlite_zstd_ext_file.c_str(), + nullptr, + &errmsg + ); + if(rc != SQLITE_OK) { + const std::string err = errmsg ? errmsg : "Unknown extension error"; + sqlite3_free(errmsg); + LOG_WARNING("Failed to load extension", {{"err", err}}); + } + else { + sqlite_extension_loaded = true; + } + sqlite3_enable_load_extension(conn, 0); + } +#else + auto db_handle = false; +#endif + // load the model + if (!ctx_server.load_model(params)) { + state.store(SERVER_STATE_ERROR); + return 1; + } else { + ctx_server.init(); + state.store(SERVER_STATE_READY); + } + + LOG_INFO("model loaded", {}); + + const auto model_meta = ctx_server.model_meta(); + + // print sample chat example to make it clear which template is used + + LOG_INFO("chat template", { + {"chat_template", common_chat_templates_source(ctx_server.chat_templates.get())}, + }); + + LOG_INFO("chat template", { + {"chat_example", common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja, {}).c_str() + }, + {"built_in", params.chat_template.empty()}, + }); + // + // Middlewares + // + + auto middleware_validate_api_key = [¶ms](const httplib::Request & req, httplib::Response & res) { + static const std::unordered_set public_endpoints = { + "/health", + "/v1/health", + "/models", + "/v1/models", + "/api/tags" + }; + + // If API key is not set, skip validation + if (params.api_keys.empty()) { + return true; + } + + // If path is public or is static file, skip validation + if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") { + return true; + } + + // Check for API key in the header + auto auth_header = req.get_header_value("Authorization"); + + std::string prefix = "Bearer "; + if (auth_header.substr(0, prefix.size()) == prefix) { + std::string received_api_key = auth_header.substr(prefix.size()); + if (std::find(params.api_keys.begin(), params.api_keys.end(), received_api_key) != params.api_keys.end()) { + return true; // API key is valid + } + } + + auth_header = req.get_header_value("X-Api-Key"); + + if (std::find(params.api_keys.begin(), params.api_keys.end(), auth_header) != params.api_keys.end()) { + return true; // API key is valid + } + + // API key is invalid or not provided + res.status = 401; + res.set_content( + (json { + {"error", { + {"message", "Invalid API Key"}, + {"type", "authentication_error"}, + {"code", 401} + }} + }).dump(-1, ' ', false, json::error_handler_t::replace), + "application/json; charset=utf-8" + ); + LOG_WARNING("Unauthorized: Invalid API Key\n", {}); + return false; + }; + + auto middleware_server_state = [&state](const httplib::Request& req, httplib::Response& res) { + server_state current_state = state.load(); + if (current_state == SERVER_STATE_LOADING_MODEL) { + auto tmp = string_split(req.path, '.'); + if (req.path == "/" || tmp.back() == "html") { + res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); + res.status = 503; + } + else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") { + // allow the models endpoint to be accessed during loading + return true; + } + else { + res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); + } + return false; + } + return true; + }; + + // register server middlewares + svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request& req, httplib::Response& res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + // If this is OPTIONS request, skip validation because browsers don't include Authorization header + if (req.method == "OPTIONS") { + res.set_header("Access-Control-Allow-Credentials", "true"); + res.set_header("Access-Control-Allow-Methods", "GET, POST"); + res.set_header("Access-Control-Allow-Headers", "*"); + res.set_content("", "text/html"); // blank response, no data + return httplib::Server::HandlerResponse::Handled; // skip further processing + } + if (!middleware_server_state(req, res)) { + return httplib::Server::HandlerResponse::Handled; + } + if (!middleware_validate_api_key(req, res)) { + return httplib::Server::HandlerResponse::Handled; + } + return httplib::Server::HandlerResponse::Unhandled; + }); + + // + // Route handlers (or controllers) + // + + const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) { + server_state current_state = state.load(); + switch (current_state) { + case SERVER_STATE_READY: + { + // request slots data using task queue + server_task task; + task.id = ctx_server.queue_tasks.get_new_id(); + task.type = SERVER_TASK_TYPE_METRICS; + task.id_target = -1; + + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(std::move(task)); + + // get the result + server_task_result result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); + + const int n_idle_slots = result.data.at("idle"); + const int n_processing_slots = result.data.at("processing"); + + json health = { + {"status", "ok"}, + {"slots_idle", n_idle_slots}, + {"slots_processing", n_processing_slots} + }; + + res.status = 200; // HTTP OK + if (params.endpoint_slots && req.has_param("include_slots")) { + health["slots"] = result.data.at("slots"); + } + + if (n_idle_slots == 0) { + health["status"] = "no slot available"; + if (req.has_param("fail_on_no_slot")) { + res.status = 503; // HTTP Service Unavailable + } + } + + res.set_content(health.dump(), "application/json"); + break; + } + case SERVER_STATE_LOADING_MODEL: + { + res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); + } break; + case SERVER_STATE_ERROR: + { + res_err(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER)); + } break; + } + }; + + const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) { + if (!params.endpoint_slots) { + res_err(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + + // request slots data using task queue + server_task task; + task.id = ctx_server.queue_tasks.get_new_id(); + task.id_multi = -1; + task.id_target = -1; + task.type = SERVER_TASK_TYPE_METRICS; + + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(std::move(task)); + + // get the result + server_task_result result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); + + res.set_content(result.data.at("slots").dump(), "application/json"); + res.status = 200; // HTTP OK + }; + + const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { + if (!params.endpoint_metrics) { + res_err(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + + // request slots data using task queue + server_task task; + task.id = ctx_server.queue_tasks.get_new_id(); + task.id_multi = -1; + task.id_target = -1; + task.type = SERVER_TASK_TYPE_METRICS; + task.data.push_back({{"reset_bucket", true}}); + + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(std::move(task)); + + // get the result + server_task_result result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); + + json data = result.data; + + const uint64_t n_prompt_tokens_processed = data.at("n_prompt_tokens_processed"); + const uint64_t t_prompt_processing = data.at("t_prompt_processing"); + + const uint64_t n_tokens_predicted = data.at("n_tokens_predicted"); + const uint64_t t_tokens_generation = data.at("t_tokens_generation"); + + const int32_t kv_cache_used_cells = data.at("kv_cache_used_cells"); + + // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names + json all_metrics_def = json { + {"counter", {{ + {"name", "prompt_tokens_total"}, + {"help", "Number of prompt tokens processed."}, + {"value", (uint64_t) data.at("n_prompt_tokens_processed_total")} + }, { + {"name", "prompt_seconds_total"}, + {"help", "Prompt process time"}, + {"value", (uint64_t) data.at("t_prompt_processing_total") / 1.e3} + }, { + {"name", "tokens_predicted_total"}, + {"help", "Number of generation tokens processed."}, + {"value", (uint64_t) data.at("n_tokens_predicted_total")} + }, { + {"name", "tokens_predicted_seconds_total"}, + {"help", "Predict process time"}, + {"value", (uint64_t) data.at("t_tokens_generation_total") / 1.e3} + }}}, + {"gauge", {{ + {"name", "prompt_tokens_seconds"}, + {"help", "Average prompt throughput in tokens/s."}, + {"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.} + },{ + {"name", "predicted_tokens_seconds"}, + {"help", "Average generation throughput in tokens/s."}, + {"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.} + },{ + {"name", "kv_cache_usage_ratio"}, + {"help", "KV-cache usage. 1 means 100 percent usage."}, + {"value", 1. * kv_cache_used_cells / params.n_ctx} + },{ + {"name", "kv_cache_tokens"}, + {"help", "KV-cache tokens."}, + {"value", (uint64_t) data.at("kv_cache_tokens_count")} + },{ + {"name", "requests_processing"}, + {"help", "Number of request processing."}, + {"value", (uint64_t) data.at("processing")} + },{ + {"name", "requests_deferred"}, + {"help", "Number of request deferred."}, + {"value", (uint64_t) data.at("deferred")} + }}} + }; + + std::stringstream prometheus; + + for (const auto & el : all_metrics_def.items()) { + const auto & type = el.key(); + const auto & metrics_def = el.value(); + + for (const auto & metric_def : metrics_def) { + const std::string name = metric_def.at("name"); + const std::string help = metric_def.at("help"); + + auto value = json_value(metric_def, "value", 0.); + prometheus << "# HELP llamacpp:" << name << " " << help << "\n" + << "# TYPE llamacpp:" << name << " " << type << "\n" + << "llamacpp:" << name << " " << value << "\n"; + } + } + + const int64_t t_start = data.at("t_start"); + res.set_header("Process-Start-Time-Unix", std::to_string(t_start)); + + res.set_content(prometheus.str(), "text/plain; version=0.0.4"); + res.status = 200; // HTTP OK + }; + + const auto handle_slots_save = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { + json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return; + } + std::string filepath = params.slot_save_path + filename; + + server_task task; + task.type = SERVER_TASK_TYPE_SLOT_SAVE; + task.data = { + { "id_slot", id_slot }, + { "filename", filename }, + { "filepath", filepath } + }; + + const int id_task = ctx_server.queue_tasks.post(std::move(task)); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + if (result.error) { + res_err(res, result.data); + } else { + res.set_content(result.data.dump(), "application/json"); + } + }; + + const auto handle_slots_restore = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { + json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return; + } + std::string filepath = params.slot_save_path + filename; + + server_task task; + task.type = SERVER_TASK_TYPE_SLOT_RESTORE; + task.data = { + { "id_slot", id_slot }, + { "filename", filename }, + { "filepath", filepath } + }; + + const int id_task = ctx_server.queue_tasks.post(std::move(task)); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + if (result.error) { + res_err(res, result.data); + } else { + res.set_content(result.data.dump(), "application/json"); + } + }; + + const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { + server_task task; + task.type = SERVER_TASK_TYPE_SLOT_ERASE; + task.data = { + { "id_slot", id_slot }, + }; + + const int id_task = ctx_server.queue_tasks.post(std::move(task)); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + if (result.error) { + res_err(res, result.data); + } else { + res.set_content(result.data.dump(), "application/json"); + } + }; + + const auto handle_slots_action = [&handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { + std::string id_slot_str = req.path_params.at("id_slot"); + int id_slot; + + try { + id_slot = std::stoi(id_slot_str); + } catch (const std::exception &) { + res_err(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + std::string action = req.get_param_value("action"); + + if (action == "save") { + handle_slots_save(req, res, id_slot); + } else if (action == "restore") { + handle_slots_restore(req, res, id_slot); + } else if (action == "erase") { + handle_slots_erase(req, res, id_slot); + } else { + res_err(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); + } + }; + + const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + std::string template_key = "tokenizer.chat_template", curr_tmpl; + int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0); + if (tlen > 0) { + std::vector curr_tmpl_buf(tlen + 1, 0); + if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) { + curr_tmpl = std::string(curr_tmpl_buf.data(), tlen); + } + } + json data = { + { "system_prompt", ctx_server.system_prompt.c_str() }, + { "model_alias", ctx_server.params_base.model_alias }, + { "model_path", ctx_server.params_base.model}, + { "default_generation_settings", ctx_server.default_generation_settings_for_props }, + { "total_slots", ctx_server.params_base.n_parallel }, + { "model_name", get_model_name(ctx_server.params_base.model)}, + { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, + { "bos_token", common_token_to_piece(ctx_server.ctx, llama_token_bos(ctx_server.model), /* special= */ true)}, + { "eos_token", common_token_to_piece(ctx_server.ctx, llama_token_eos(ctx_server.model), /* special= */ true)}, + { "model_path", ctx_server.params_base.model }, + { "modalities", json { + {"vision", ctx_server.oai_parser_opt.allow_image}, + {"audio", ctx_server.oai_parser_opt.allow_audio}, + } }, + { "n_ctx", ctx_server.n_ctx } + + }; + + if (ctx_server.params_base.use_jinja) { + if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) { + data["chat_template_tool_use"] = tool_use_src; + } + } + res.set_content(data.dump(), "application/json; charset=utf-8"); + }; + + const auto handle_props_simple = [&ctx_server](const httplib::Request& req, httplib::Response& res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + int n_past = 0; + int slot_id = 0; + for (server_slot& slot : ctx_server.slots) { + if (slot.n_past > n_past) { + n_past = slot.n_past; + slot_id = slot.id; + } + } + json data = { + { "model_name", get_model_name(ctx_server.params_base.model)}, + { "model_path", ctx_server.params_base.model }, + { "modalities", json { + {"vision", ctx_server.oai_parser_opt.allow_image}, + {"audio", ctx_server.oai_parser_opt.allow_audio}, + } }, + { "n_ctx", ctx_server.n_ctx } + }; + res.set_content(data.dump(), "application/json; charset=utf-8"); + }; + + + + +// handle completion-like requests (completion, chat, infill) + // we can optionally provide a custom format for partial results and final results + const auto handle_completions_impl = [&ctx_server, ¶ms]( + server_task_type type, + json& data, + const std::vector& files, + const std::function& is_connection_closed, + httplib::Response& res, + oaicompat_type oaicompat) -> void { + GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); + + // ---------------------------------------------------------------- + // 1. Regex Validation + // ---------------------------------------------------------------- + auto validate_regex_list = [&](const std::string& field_name) -> std::string { + if (data.contains(field_name) && data[field_name].is_array()) { + for (const auto& val : data[field_name]) { + if (val.is_string()) { + std::string s = val.get(); + if (!s.empty()) { + try { + std::regex re(s); + } catch (const std::regex_error& e) { + return s; + } + } + } + } + } + return ""; + }; + + std::string invalid_re = validate_regex_list("banned_regex"); + if (invalid_re.empty()) invalid_re = validate_regex_list("banned_regex_case_insensitive"); + + if (!invalid_re.empty()) { + res_err(res, format_error_response("Invalid regex: " + invalid_re, ERROR_TYPE_INVALID_REQUEST)); + return; + } + + const auto completion_id = gen_chatcmplid(); + + // Process prompt / inputs + std::vector inputs; + try { + const auto& prompt = data.at("prompt"); + if (oaicompat && ctx_server.mctx != nullptr) { + inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); + } + else { + inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true); + } + } + catch (const std::exception& e) { + res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return; + } + + // ---------------------------------------------------------------- + // Check if we need the complex "Banned String" logic + // Only enable if the lists are present AND contain actual strings. + // ---------------------------------------------------------------- + auto list_has_content = [&](const std::string& key) { + if (data.contains(key) && data[key].is_array()) { + for (const auto& item : data[key]) { + if (item.is_string() && !item.get().empty()) { + return true; + } + } + } + return false; + }; + + bool has_banned_content = list_has_content("banned_strings") || + list_has_content("banned_regex") || + list_has_content("banned_regex_case_insensitive"); + + if (!has_banned_content) { + // ---------------------------------------------------------------- + // PATH A: Standard Logic (server_response_reader) + // ---------------------------------------------------------------- + + // need to store the reader as a pointer, so that it won't be destroyed when the handle returns + // use shared_ptr as it's shared between the chunked_content_provider() and on_complete() + const auto rd = std::make_shared(ctx_server); + + try { + std::vector tasks; + tasks.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + server_task task = server_task(type); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + + task.tokens = std::move(inputs[i]); + task.data = data; + task.id_slot = json_value(data, "id_slot", -1); + + // OAI-compat + task.params.oaicompat = oaicompat; + task.params.oaicompat_cmpl_id = completion_id; + task.params.oaicompat_model = get_model_name(ctx_server.params_base.model); + tasks.push_back(std::move(task)); + } + + rd->post_tasks(std::move(tasks)); + } + catch (const std::exception& e) { + res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return; + } + bool stream = json_value(data, "stream", false); + if (!stream) { + // non-stream, wait for the results + auto all_results = rd->wait_for_all(is_connection_closed); + if (all_results.is_terminated) { + llama_decode_stop(); // send a signal to stop decode process + return; // connection is closed + } + else if (all_results.error) { + res_err(res, all_results.error->to_json()); + return; + } + else { + json arr = json::array(); + for (auto& res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + if (oaicompat) { + arr.push_back(format_final_response_oaicompat(data, res->data, completion_id, false)); + } else { + arr.push_back(res->to_json()); + } + } + // if single request, return single object instead of array + res_ok(res, arr.size() == 1 ? arr[0] : arr); + } + } + else { + // in streaming mode, the first error must be treated as non-stream response + // this is to match the OAI API behavior + // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 + server_task_result_ptr first_result = rd->next(is_connection_closed); + if (first_result == nullptr) { + llama_decode_stop(); // send a signal to stop decode process + return; // connection is closed + } + else if (first_result->is_error()) { + res_err(res, first_result->to_json()); + return; + } + else { + GGML_ASSERT( + dynamic_cast(first_result.get()) != nullptr + || dynamic_cast(first_result.get()) != nullptr + ); + } + + // Prepare first result JSON (handling OAI format if needed) + std::vector first_result_parts; + if (oaicompat) { + first_result_parts = format_partial_response_oaicompat(*first_result, completion_id); + } else { + first_result_parts.push_back(first_result->to_json()); + } + + const auto chunked_content_provider = [first_result_parts, rd, oaicompat, completion_id](size_t, httplib::DataSink& sink) mutable -> bool { + const auto sse = [oaicompat, &sink](const json& res) { + if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) { + return server_sent_anthropic_event(sink, res); + } + else { + return server_sent_event(sink, res); + } + }; + + // flush the first result parts + for (auto& part : first_result_parts) { + if (!part.empty()) { + if (!sse(part)) { + sink.done(); + return false; // sending failed, go to on_complete() + } + part.clear(); // mark as sent + } + } + + // receive subsequent results + auto result = rd->next([&sink] { return !sink.is_writable(); }); + if (result == nullptr) { + sink.done(); + return false; // connection is closed, go to on_complete() + } + + // send the results + bool ok = false; + if (result->is_error()) { + ok = sse(json{ { "error", result->to_json() } }); + sink.done(); + return false; // go to on_complete() + } + else { + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + + if (oaicompat) { + std::vector parts = format_partial_response_oaicompat(*result, completion_id); + for (const auto& part : parts) { + ok = sse(part); + if (!ok) break; + } + } else { + ok = sse(result->to_json()); + } + } + + if (!ok) { + sink.done(); + return false; // sending failed, go to on_complete() + } + + // check if there is more data + if (!rd->has_next()) { + if (oaicompat != OAICOMPAT_TYPE_ANTHROPIC && oaicompat != OAICOMPAT_TYPE_NONE) { + static const std::string ev_done = "data: [DONE]\n\n"; + sink.write(ev_done.data(), ev_done.size()); + } + sink.done(); + return false; // no more data, go to on_complete() + } + + // has next data, continue + return true; + }; + + auto on_complete = [rd](bool) { + rd->stop(); + }; + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + } + + } else { + // ---------------------------------------------------------------- + // PATH B: Banned Content Logic (Slow Path with Buffering & Rewind) + // ---------------------------------------------------------------- + auto buffer_and_check_string_ban_and_rewind_logic = [&]() { + // Helper to mimic request_cancel using the task queue directly + auto request_cancel = [&ctx_server](int id_target) { + server_task task(SERVER_TASK_TYPE_CANCEL); + task.id_target = id_target; + std::vector tasks; + tasks.push_back(std::move(task)); + ctx_server.queue_tasks.post(std::move(tasks), true); + }; + + // Helper to post a completion task with correct OAI params + auto post_task_with_params = [&ctx_server, oaicompat, completion_id](int id_task, json& task_data, server_tokens& tokens) { + server_task task(SERVER_TASK_TYPE_COMPLETION); + task.id = id_task; + task.index = 0; + task.tokens = std::move(tokens); + task.data = task_data; + task.id_slot = json_value(task_data, "id_slot", -1); + + // Critical: Set OAI params so worker generates correct output format + task.params.oaicompat = oaicompat; + task.params.oaicompat_cmpl_id = completion_id; + task.params.oaicompat_model = get_model_name(ctx_server.params_base.model); + + std::vector tasks; + tasks.push_back(std::move(task)); + ctx_server.queue_tasks.post(std::move(tasks)); + }; + + const int id_task = ctx_server.queue_tasks.get_new_id(); + ctx_server.queue_results.add_waiting_task_id(id_task); + + // Use helper instead of request_completion + post_task_with_params(id_task, data, inputs[0]); + + bool stream = json_value(data, "stream", false); + + if (!stream) { + // Non-streaming: wait for result (using pointer to avoid slicing) + std::unordered_set ids = { id_task }; + server_task_result_ptr result = nullptr; + + // Simple blocking wait + while (!result) { + result = ctx_server.queue_results.recv_with_timeout(ids, 1); + if (!result && is_connection_closed()) { + request_cancel(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + return; + } + } + + if (!result->is_error()) { + json result_json; + if (oaicompat) { + result_json = format_final_response_oaicompat(data, result->data, completion_id, false); + } else { + result_json = result->to_json(); + } + res.set_content(result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); + } + else { + res_err(res, result->to_json()); + } + ctx_server.queue_results.remove_waiting_task_id(id_task); + } + else { + // Shared state to track the currently running task ID across retries. + auto active_task_id = std::make_shared(id_task); + + // Capture 'data' by value to use as a template for retries + const auto chunked_content_provider = [id_task, active_task_id, &ctx_server, completion_id, oaicompat, send_done = params.send_done, data, request_cancel, post_task_with_params](size_t, httplib::DataSink& sink) mutable { + // Define sse here so it's visible to both try and catch blocks + const auto sse = [oaicompat, &sink](const json &res) { + if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) { + return server_sent_anthropic_event(sink, res); + } else { + return server_sent_event(sink, res); + } + }; + + try { + bool successful_completion = false; + + // 1. Parse Configuration from Request + + // Banned Strings + std::vector stop_phrases; + if (data.contains("banned_strings") && data["banned_strings"].is_array()) { + for (const auto& val : data["banned_strings"]) { + if (val.is_string()) { + std::string s = val.get(); + if (!s.empty()) stop_phrases.push_back(s); + } + } + } + + // Sort banned strings by length (descending) + std::sort(stop_phrases.begin(), stop_phrases.end(), [](const std::string& a, const std::string& b) { + return a.length() > b.length(); + }); + + // Banned Regex (Case Sensitive & Insensitive) + std::vector regex_patterns; // For buffer size calculation + std::vector stop_regexes; // Compiled regexes + + auto add_regex_list = [&](const std::string& field_name, bool case_insensitive) { + if (data.contains(field_name) && data[field_name].is_array()) { + for (const auto& val : data[field_name]) { + if (val.is_string()) { + std::string s = val.get(); + if (!s.empty()) { + auto flags = std::regex_constants::ECMAScript; + if (case_insensitive) flags |= std::regex_constants::icase; + stop_regexes.emplace_back(s, flags); + regex_patterns.push_back(s); + } + } + } + } + }; + + // We assume validation passed in handle_completions_impl, so no try-catch needed here + add_regex_list("banned_regex", false); + add_regex_list("banned_regex_case_insensitive", true); + + // Logit Bias Penalty (Default: -10000.0) + float ban_bias = -10000.0f; + if (data.contains("banned_bias") && data["banned_bias"].is_number()) { + ban_bias = data["banned_bias"].get(); + } + + // Manual Buffer Size + size_t manual_buffer_size = 0; + if (data.contains("banbuffer_size") && data["banbuffer_size"].is_number_unsigned()) { + manual_buffer_size = data["banbuffer_size"].get(); + } + + // Token Limit Tracking + int original_n_predict = -1; + if (data.contains("n_predict") && data["n_predict"].is_number_integer()) { + original_n_predict = data["n_predict"].get(); + } + int total_tokens_streamed = 0; + + // ============================================================ + // FAST PATH: No banned strings AND No regex -> No buffering + // ============================================================ + if (stop_phrases.empty() && stop_regexes.empty()) { + while (true) { + std::unordered_set ids = { *active_task_id }; + server_task_result_ptr result = nullptr; + while (!result) { + result = ctx_server.queue_results.recv_with_timeout(ids, 1); + if (!result && !sink.is_writable()) { + request_cancel(*active_task_id); + ctx_server.queue_results.remove_waiting_task_id(*active_task_id); + return false; + } + } + + if (!result->is_error()) { + // Use format_partial_response_oaicompat to get the correct chunks + std::vector parts; + if (oaicompat) { + parts = format_partial_response_oaicompat(*result, completion_id); + } else { + parts.push_back(result->data); + } + + for (const auto& item : parts) { + if (!sse(item)) { + request_cancel(*active_task_id); + ctx_server.queue_results.remove_waiting_task_id(*active_task_id); + return false; + } + } + + if (result->is_stop()) { + successful_completion = true; + break; + } + } else { + sse(result->to_json()); + ctx_server.queue_results.remove_waiting_task_id(*active_task_id); + return false; + } + } + } + // ============================================================ + // SLOW PATH: Buffering and Banning Logic + // ============================================================ + else { + // Calculate Buffer Size + size_t BUFFER_SIZE; + if (manual_buffer_size > 0) { + BUFFER_SIZE = manual_buffer_size; + } else { + size_t max_len = 0; + // Check strings + if (!stop_phrases.empty()) { + max_len = stop_phrases[0].length(); // First is longest due to sort + } + // Check regex patterns + for (const auto& pat : regex_patterns) { + if (pat.length() > max_len) max_len = pat.length(); + } + + // Default: Longest string/regex + 1 + BUFFER_SIZE = std::max((size_t)1, max_len + 1); + } + + // Initialize Buffer & State + std::deque token_buffer; + + int current_task_id = id_task; + + // Track bans specifically for the current "next token" to be generated. + std::set current_step_bans; + int ban_slot_index = -1; + + // Track the text that has been confirmed/sent to the client. + std::string current_prompt_str = ""; + if (data.contains("prompt") && data["prompt"].is_string()) { + current_prompt_str = data["prompt"].get(); + } + + // Helper to extract text content + auto get_content_str = [](const json& j) -> std::string { + if (j.contains("choices") && j["choices"].is_array() && !j["choices"].empty()) { + const auto& choice = j["choices"][0]; + if (choice.contains("delta") && choice["delta"].contains("content")) { + auto val = choice["delta"]["content"]; + if (val.is_string()) return val.get(); + } + } + if (j.contains("content")) { + auto val = j["content"]; + if (val.is_string()) return val.get(); + } + return ""; + }; + + // Helper to extract Token ID + auto get_token_id = [](const json& j) -> int { + if (j.contains("__raw_token_id")) return j["__raw_token_id"].get(); + if (j.contains("token")) return j["token"].get(); + if (j.contains("id")) return j["id"].get(); + return -1; + }; + + // Helper for case-insensitive search + auto to_lower_str = [](std::string s) { + std::transform(s.begin(), s.end(), s.begin(), + [](unsigned char c){ return std::tolower(c); }); + return s; + }; + + // Helper to print buffer + auto print_debug_buffer = [&](const std::deque& buf) { + std::cout << "Debug TokenBuffer (Size " << BUFFER_SIZE << "): ["; + size_t print_len = std::max(buf.size(), BUFFER_SIZE); + for (size_t i = 0; i < print_len; ++i) { + if (i < buf.size()) { + std::string content = get_content_str(buf[i]); + std::string escaped; + for (char c : content) { + if (c == '\n') escaped += "\\n"; + else if (c == '"') escaped += "\\\""; + else escaped += c; + } + std::cout << "\"" << escaped << "\""; + } else { + std::cout << "\"\""; + } + if (i < print_len - 1) std::cout << ", "; + } + std::cout << "]" << std::endl; + }; + + while (true) { + // Ensure shared state matches current local state + *active_task_id = current_task_id; + + // 0. Check connection status explicitly + if (!sink.is_writable()) { + request_cancel(current_task_id); + ctx_server.queue_results.remove_waiting_task_id(current_task_id); + return false; + } + + // Receive from the CURRENT task ID using pointer to avoid slicing + std::unordered_set ids = { current_task_id }; + server_task_result_ptr result = nullptr; + while (!result) { + result = ctx_server.queue_results.recv_with_timeout(ids, 1); + if (!result && !sink.is_writable()) { + request_cancel(current_task_id); + ctx_server.queue_results.remove_waiting_task_id(current_task_id); + return false; + } + } + + std::vector items_to_buffer; + + if (!result->is_error()) { + // Use format_partial_response_oaicompat to get the correct chunks + std::vector parts; + if (oaicompat) { + parts = format_partial_response_oaicompat(*result, completion_id); + } else { + parts.push_back(result->data); + } + + json raw_data = result->data; // Access raw data for token ID + + for (const auto& r : parts) { + json item = r; + // Attach raw token ID for banning logic + if (raw_data.contains("token")) item["__raw_token_id"] = raw_data["token"]; + items_to_buffer.push_back(item); + } + } else { + items_to_buffer.push_back(result->to_json()); + } + + // 2. Process items into buffer + for (const auto& item : items_to_buffer) { + token_buffer.push_back(item); + } + + print_debug_buffer(token_buffer); + + // 3. Check for Stop Phrases (Strings & Regex) + std::string buffer_text = ""; + std::vector token_offsets; + + for (const auto& item : token_buffer) { + token_offsets.push_back(buffer_text.length()); + buffer_text += get_content_str(item); + } + + std::string buffer_lower = to_lower_str(buffer_text); + + size_t match_pos = std::string::npos; + std::string detected_phrase = ""; + + // A. Check Strings (Case Insensitive) + for (const auto& phrase : stop_phrases) { + std::string target_lower = to_lower_str(phrase); + size_t pos = buffer_lower.find(target_lower); + if (pos != std::string::npos) { + if (match_pos == std::string::npos || pos < match_pos) { + match_pos = pos; + detected_phrase = phrase; + } + } + } + + // B. Check Regex + for (size_t i = 0; i < stop_regexes.size(); ++i) { + std::smatch match; + // We search the raw buffer_text + if (std::regex_search(buffer_text, match, stop_regexes[i])) { + size_t pos = match.position(0); + if (match_pos == std::string::npos || pos < match_pos) { + match_pos = pos; + detected_phrase = "REGEX:" + regex_patterns[i]; + } + } + } + + if (match_pos != std::string::npos) { + std::cout << "Debug: Stop phrase '" << detected_phrase << "' detected. Initiating ban logic." << std::endl; + + // Find the guilty token + size_t split_index = 0; + bool found_split = false; + for (size_t i = 0; i < token_offsets.size(); ++i) { + size_t token_start = token_offsets[i]; + std::string content = get_content_str(token_buffer[i]); + size_t token_end = token_start + content.length(); + + if (token_end > match_pos) { + split_index = i; + found_split = true; + break; + } + } + + if (found_split) { + // 1. Construct prompt from good tokens (DO NOT FLUSH) + std::string temp_prompt_suffix = ""; + std::deque good_tokens; + + for (size_t i = 0; i < split_index; ++i) { + json& item = token_buffer[i]; + if (item.contains("__raw_token_id")) item.erase("__raw_token_id"); + temp_prompt_suffix += get_content_str(item); + good_tokens.push_back(item); + } + + // 2. Identify Guilty Token & Add to Bans + json& guilty_item = token_buffer[split_index]; + int guilty_token_id = get_token_id(guilty_item); + + if (guilty_token_id == -1) { + std::string content = get_content_str(guilty_item); + auto tokens = ctx_server.tokenize(content, false); + if (!tokens.empty()) guilty_token_id = tokens[0]; + } + + if (guilty_token_id != -1) { + // Check if we are banning a different slot than before + if (ban_slot_index != (int)split_index) { + current_step_bans.clear(); + ban_slot_index = (int)split_index; + } + + current_step_bans.insert(guilty_token_id); + std::cout << "Debug: Banning token ID " << guilty_token_id << " at slot " << split_index << ". Total bans: " << current_step_bans.size() << std::endl; + + // 3. Cancel current task + request_cancel(current_task_id); + ctx_server.queue_results.remove_waiting_task_id(current_task_id); + + // 4. FIX STEP: Generate 1 token with ALL current bans + json fix_data = data; + fix_data["prompt"] = current_prompt_str + temp_prompt_suffix; + fix_data["n_predict"] = 1; + + // Robust logit_bias handling + if (!fix_data.contains("logit_bias")) { + fix_data["logit_bias"] = json::array(); + } + + if (fix_data["logit_bias"].is_array()) { + for (int banned_id : current_step_bans) { + fix_data["logit_bias"].push_back(json::array({banned_id, ban_bias})); + } + } else if (fix_data["logit_bias"].is_object()) { + for (int banned_id : current_step_bans) { + fix_data["logit_bias"][std::to_string(banned_id)] = ban_bias; + } + } + + std::cout << "Debug: Fix Data Logit Bias: " << fix_data["logit_bias"].dump() << std::endl; + + int id_fix = ctx_server.queue_tasks.get_new_id(); + *active_task_id = id_fix; // Update shared state for fix task + ctx_server.queue_results.add_waiting_task_id(id_fix); + + std::vector fix_inputs = tokenize_input_prompts( + llama_get_vocab(ctx_server.ctx), ctx_server.mctx, fix_data["prompt"], true, true + ); + + // Use helper + post_task_with_params(id_fix, fix_data, fix_inputs[0]); + + // Wait for the fix token + std::unordered_set fix_ids = { id_fix }; + server_task_result_ptr fix_result = nullptr; + while (!fix_result) { + fix_result = ctx_server.queue_results.recv_with_timeout(fix_ids, 1); + if (!fix_result && !sink.is_writable()) { + request_cancel(id_fix); + ctx_server.queue_results.remove_waiting_task_id(id_fix); + return false; + } + } + ctx_server.queue_results.remove_waiting_task_id(id_fix); + + // Check for error in fix result + if (fix_result->is_error()) { + std::cout << "Debug: Fix task failed with error." << std::endl; + sse(fix_result->to_json()); + return false; + } + + // Process fix token + json fix_token_json; + json raw_fix = fix_result->data; + + // Use format_partial_response_oaicompat for fix token too + if (oaicompat) { + std::vector parts = format_partial_response_oaicompat(*fix_result, completion_id); + if (!parts.empty()) fix_token_json = parts[0]; + } else { + fix_token_json = fix_result->data; + } + + if (raw_fix.contains("token")) fix_token_json["__raw_token_id"] = raw_fix["token"]; + + std::string fix_content = get_content_str(fix_token_json); + + // 5. RESUME STEP: Continue generation normally + json resume_data = data; + bool stop_after_fix = false; + + if (original_n_predict > 0) { + int pending = good_tokens.size() + 1; + if (total_tokens_streamed + pending >= original_n_predict) { + stop_after_fix = true; + } else { + resume_data["n_predict"] = original_n_predict - (total_tokens_streamed + pending); + } + } + + if (stop_after_fix) { + token_buffer = good_tokens; + token_buffer.push_back(fix_token_json); + + while (!token_buffer.empty()) { + json& item = token_buffer.front(); + if (item.contains("__raw_token_id")) item.erase("__raw_token_id"); + if (!sse(item)) { + request_cancel(*active_task_id); + ctx_server.queue_results.remove_waiting_task_id(*active_task_id); + return false; + } + total_tokens_streamed++; + token_buffer.pop_front(); + } + successful_completion = true; + goto cleanup; + } + + resume_data["prompt"] = current_prompt_str + temp_prompt_suffix + fix_content; + + current_task_id = ctx_server.queue_tasks.get_new_id(); + *active_task_id = current_task_id; // Update shared state for resume task + ctx_server.queue_results.add_waiting_task_id(current_task_id); + + std::vector resume_inputs = tokenize_input_prompts( + llama_get_vocab(ctx_server.ctx), ctx_server.mctx, resume_data["prompt"], true, true + ); + + // Use helper + post_task_with_params(current_task_id, resume_data, resume_inputs[0]); + + // 6. Update Buffer: Good Tokens + Fix Token + token_buffer = good_tokens; + token_buffer.push_back(fix_token_json); + + // REMOVED continue; to allow flush logic to run + } + } + } + + // 4. Standard Flush Logic + bool should_flush_all = result->is_stop() || result->is_error(); + + if (token_buffer.size() >= BUFFER_SIZE || should_flush_all) { + while (!token_buffer.empty()) { + if (!should_flush_all && token_buffer.size() < BUFFER_SIZE) { + break; + } + + json& item_to_send = token_buffer.front(); + if (item_to_send.contains("__raw_token_id")) item_to_send.erase("__raw_token_id"); + + current_prompt_str += get_content_str(item_to_send); + + // SMART BAN CLEARING LOGIC + if (ban_slot_index != -1) { + if (0 == ban_slot_index) { + // We are flushing the slot that had bans. + // This means it's now accepted (or we are forced to flush). + current_step_bans.clear(); + ban_slot_index = -1; + } else { + // We are flushing a preceding token. + // The banned slot shifts left. + ban_slot_index--; + } + } + + if (!sse(item_to_send)) { + request_cancel(current_task_id); + ctx_server.queue_results.remove_waiting_task_id(current_task_id); + return false; + } + + total_tokens_streamed++; + token_buffer.pop_front(); + + if (original_n_predict > 0 && total_tokens_streamed >= original_n_predict) { + request_cancel(current_task_id); + ctx_server.queue_results.remove_waiting_task_id(current_task_id); + successful_completion = true; + goto cleanup; + } + } + } + + if (result->is_error()) { + ctx_server.queue_results.remove_waiting_task_id(current_task_id); + return false; + } + + if (result->is_stop()) { + successful_completion = true; + break; + } + } + } + + cleanup: + bool ok = true; + if (successful_completion && oaicompat != OAICOMPAT_TYPE_ANTHROPIC && oaicompat != OAICOMPAT_TYPE_NONE) { + static const std::string done_message = "data: [DONE]\n\n"; + LOG_VERBOSE("data stream", { {"to_send", done_message} }); + if (!sink.write(done_message.c_str(), done_message.size())) { + ok = false; + } + } + sink.done(); + + // Cleanup the active task ID (which might be different from id_task in slow path) + ctx_server.queue_results.remove_waiting_task_id(*active_task_id); + + return ok; + } catch (const std::exception& e) { + // Catch any exceptions to prevent crashing the server + std::cerr << "Exception in streaming handler: " << e.what() << std::endl; + sse(json{{"error", {{"message", e.what()}, {"type", "server_error"}, {"code", 500}}}}); + sink.done(); + if (active_task_id) { + request_cancel(*active_task_id); + ctx_server.queue_results.remove_waiting_task_id(*active_task_id); + } + return false; + } catch (...) { + std::cerr << "Unknown exception in streaming handler" << std::endl; + sse(json{{"error", {{"message", "Unknown error"}, {"type", "server_error"}, {"code", 500}}}}); + sink.done(); + if (active_task_id) { + request_cancel(*active_task_id); + ctx_server.queue_results.remove_waiting_task_id(*active_task_id); + } + return false; + } + }; + + auto on_complete = [active_task_id, &ctx_server, request_cancel](bool) { + // Cancel the currently active task ID + int id_to_cancel = *active_task_id; + request_cancel(id_to_cancel); + ctx_server.queue_results.remove_waiting_task_id(id_to_cancel); + }; + + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + } + }; + + // Execute the complex logic + buffer_and_check_string_ban_and_rewind_logic(); + } + }; + + + + const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + auto data = json::parse(req.body); + std::vector files; // dummy + handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + data, + files, + req.is_connection_closed, + res, + OAICOMPAT_TYPE_NONE); + }; + + const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request& req, httplib::Response& res) { + auto body = json::parse(req.body); + json data = oaicompat_chat_params_parse(body); + std::vector files; // dummy + handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + data, + files, + req.is_connection_closed, + res, + OAICOMPAT_TYPE_COMPLETION); + }; + + const auto handle_models = [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) { + json models = { + {"object", "list"}, + {"data", { + { + {"id", params.model_alias}, + {"object", "model"}, + {"created", std::time(0)}, + {"owned_by", "llamacpp"}, + {"meta", model_meta} + }, + }} + }; + + res.set_content(models.dump(), "application/json; charset=utf-8"); + }; + + + + const auto handle_chat_completions = [&ctx_server, ¶ms, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + auto body = json::parse(req.body); + std::vector files; + json data = oaicompat_chat_params_parse(ctx_server.model, body, ctx_server.oai_parser_opt, files); + handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + data, + files, + req.is_connection_closed, + res, + OAICOMPAT_TYPE_CHAT); + }; + + const auto handle_anthropic_messages = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + std::vector files; + json body = json::parse(req.body); + json body_parsed = anthropic_params_from_json( + ctx_server.model, + body, + ctx_server.oai_parser_opt, + files); + return handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + body_parsed, + files, + req.is_connection_closed, + res, + OAICOMPAT_TYPE_ANTHROPIC); + }; + + const auto handle_anthropic_count_tokens = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + std::vector files; + json body = json::parse(req.body); + + // Parse the Anthropic request (max_tokens is not required for count_tokens) + json body_parsed = anthropic_params_from_json( + ctx_server.model, + body, + ctx_server.oai_parser_opt, + files); + + json prompt = body_parsed.at("prompt"); + llama_tokens tokens = tokenize_mixed(llama_get_vocab(ctx_server.ctx), prompt, true, true); + + res_ok(res, {{"input_tokens", static_cast(tokens.size())}}); + return res; + }; + + // same with handle_chat_completions, but without inference part + const auto handle_apply_template = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) { + auto body = json::parse(req.body); + std::vector files; // dummy, unused + json data = oaicompat_chat_params_parse(ctx_server.model, body,ctx_server.oai_parser_opt, files); + res_ok(res, { { "prompt", std::move(data.at("prompt")) } }); + }; + + const auto handle_infill = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + json data = json::parse(req.body); + const int id_task = ctx_server.queue_tasks.get_new_id(); + server_tokens token; // dummy tokens + ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.request_completion(id_task, -1, data, true, false, std::move(token)); + std::vector files; // dummy + handle_completions_impl( + SERVER_TASK_TYPE_INFILL, + data, + files, + req.is_connection_closed, + res, + OAICOMPAT_TYPE_NONE); // infill is not OAI compatible + }; + + const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + const json body = json::parse(req.body); + + std::vector tokens; + if (body.count("content") != 0) { + const bool add_special = json_value(body, "add_special", false); + tokens = ctx_server.tokenize(body.at("content"), add_special); + } + const json data = format_tokenizer_response(tokens); + return res.set_content(data.dump(), "application/json; charset=utf-8"); + }; + + const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + const json body = json::parse(req.body); + + std::string content; + if (body.count("tokens") != 0) { + const std::vector tokens = body.at("tokens"); + content = tokens_to_str(ctx_server.ctx, tokens); + } + + const json data = format_detokenized_response(content); + return res.set_content(data.dump(), "application/json; charset=utf-8"); + }; + + const auto handle_embeddings_impl = [&ctx_server](const httplib::Request& req, httplib::Response& res, oaicompat_type oaicompat) { + if (!ctx_server.params_base.embedding) { + res_err(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + + if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + res_err(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + const json body = json::parse(req.body); + + // for the shape of input/content, see tokenize_input_prompts() + json prompt; + if (body.count("input") != 0) { + prompt = body.at("input"); + } + else if (body.contains("content")) { + oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible + prompt = body.at("content"); + } + else { + res_err(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + bool use_base64 = false; + if (body.count("encoding_format") != 0) { + const std::string& format = body.at("encoding_format"); + if (format == "base64") { + use_base64 = true; + } + else if (format != "float") { + res_err(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } + auto vocab = llama_get_vocab(ctx_server.ctx); + auto tokenized_prompts = tokenize_input_prompts(vocab, ctx_server.mctx, prompt, true, true); + for (const auto& tokens : tokenized_prompts) { + // this check is necessary for models that do not add BOS token to the input + if (tokens.empty()) { + res_err(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } + + int embd_normalize = 2; // default to Euclidean/L2 norm + if (body.count("embd_normalize") != 0) { + embd_normalize = body.at("embd_normalize"); + if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx)); + } + } + + // create and queue the task + json responses = json::array(); + server_response_reader rd(ctx_server); + { + std::vector tasks; + for (size_t i = 0; i < tokenized_prompts.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tokenized_prompts[i]); + + // OAI-compat + task.params.oaicompat = oaicompat; + task.params.embd_normalize = embd_normalize; + task.embedding = true; // probably not needed + tasks.push_back(std::move(task)); + } + rd.post_tasks(std::move(tasks)); + } + + // wait for the results + auto all_results = rd.wait_for_all(req.is_connection_closed); + + // collect results + if (all_results.is_terminated) { + llama_decode_stop(); + return; // connection is closed + } + else if (all_results.error) { + res_err(res, all_results.error->to_json()); + return; + } + else { + for (auto& res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); + } + } + + // write JSON response + json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING + ? format_embeddings_response_oaicompat(body, responses, use_base64) + : json(responses); + res_ok(res, root); + + }; + + const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request& req, httplib::Response& res) { + handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE); + }; + + const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request& req, httplib::Response& res) { + handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING); + }; + + + const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) { + json result = json::array(); + for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) { + auto & la = ctx_server.lora_adapters[i]; + result.push_back({ + {"id", i}, + {"path", la.path}, + {"scale", la.scale}, + }); + } + res.set_content(result.dump(), "application/json"); + res.status = 200; // HTTP OK + }; + + + const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) { + const std::vector body = json::parse(req.body); + int max_idx = ctx_server.lora_adapters.size(); + + // clear existing value + for (auto & la : ctx_server.lora_adapters) { + la.scale = 0.0f; + } + + // set value + for (auto entry : body) { + int id = entry.at("id"); + float scale = entry.at("scale"); + if (0 <= id && id < max_idx) { + ctx_server.lora_adapters[id].scale = scale; + } else { + throw std::runtime_error("invalid adapter id"); + } + } + + server_task task; + task.type = SERVER_TASK_TYPE_SET_LORA; + const int id_task = ctx_server.queue_tasks.post(std::move(task)); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + res.set_content(result.data.dump(), "application/json"); + res.status = 200; // HTTP OK + }; + + const auto list_saved_prompts = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) { + json response = json::array(); + + try { + for (const auto& entry : fs::directory_iterator(params.slot_save_path)) { + if (!entry.is_regular_file() || entry.file_size() < 12) { + continue; + } + + std::ifstream file(entry.path(), std::ios::binary); + if (!file) continue; + + uint32_t magic, version, n_token_count; + file.read(reinterpret_cast(&magic), sizeof(magic)); + file.read(reinterpret_cast(&version), sizeof(version)); + file.read(reinterpret_cast(&n_token_count), sizeof(n_token_count)); + + if (magic != LLAMA_STATE_SEQ_MAGIC || + version != LLAMA_STATE_SEQ_VERSION || + entry.file_size() < (12 + (n_token_count * sizeof(llama_token)))) { + continue; + } + + std::vector tokens(n_token_count); + file.read(reinterpret_cast(tokens.data()), tokens.size() * sizeof(llama_token)); + + //C++17 is not modern enough to have a nice and portable way to get the mtime of a file + //so the following seems to be needed + auto ftime = fs::last_write_time(entry.path()); + auto system_time = std::chrono::time_point_cast( + ftime - fs::file_time_type::clock::now() + std::chrono::system_clock::now() + ); + std::time_t c_time = std::chrono::system_clock::to_time_t(system_time); + std::tm tm_struct; + #if defined(_WIN32) + localtime_s(&tm_struct, &c_time); + #else + localtime_r(&c_time, &tm_struct); + #endif + std::ostringstream oss; + oss << std::put_time(&tm_struct, "%Y-%m-%d %H:%M:%S"); + auto str_time = oss.str(); + + + response.push_back({ + {"filename", entry.path().filename().string()}, + {"filesize", entry.file_size()}, + {"mtime", str_time}, + {"token_count", n_token_count}, + {"prompt", tokens_to_str(ctx_server.ctx, tokens)} + }); + } + } catch (const std::exception& e) { + res.status = 500; + response = {{"error", e.what()}}; + } + res.set_content(response.dump(), "application/json; charset=utf-8"); + }; + + const auto list_slot_prompts = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) { + json response = json::array(); + for (server_slot & slot : ctx_server.slots) { + response.push_back({ + {"slot_id", slot.id}, + {"token_count", slot.cache_tokens.size()}, + {"prompt", slot.cache_tokens.detokenize(ctx_server.ctx, true) } + }); + } + res.set_content(response.dump(), "application/json; charset=utf-8"); + }; + + + const auto delete_saved_prompt = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res)-> void { + json response; + namespace fs = std::filesystem; + + try { + const json body = json::parse(req.body); + const std::string filename_str = body.at("filename"); + + // prevent directory traversal attacks + if (filename_str.find("..") != std::string::npos || filename_str.find('/') != std::string::npos || filename_str.find('\\') != std::string::npos) { + res.status = 400; + response = {{"error", "Invalid filename format."}}; + res.set_content(response.dump(), "application/json; charset=utf-8"); + return; + } + + const fs::path file_to_delete = fs::path(params.slot_save_path) / fs::path(filename_str); + + if (!fs::exists(file_to_delete) || !fs::is_regular_file(file_to_delete)) { + res.status = 404; + response = {{"error", "File not found."}}; + res.set_content(response.dump(), "application/json; charset=utf-8"); + return; + } + + if (fs::remove(file_to_delete)) { + response = { + {"status", "deleted"}, + {"filename", filename_str} + }; + } else { + res.status = 500; + response = {{"error", "Failed to delete the file."}}; + } + } catch (const json::parse_error& e) { + res.status = 400; + response = {{"error", "Invalid JSON request body."}}; + } catch (const json::out_of_range& e) { + res.status = 400; + response = {{"error", "Missing 'filename' key in request body."}}; + } catch (const std::exception& e) { + res.status = 500; + response = {{"error", e.what()}}; + } + res.set_content(response.dump(), "application/json; charset=utf-8"); + }; + + const auto rename_saved_prompt = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res)-> void { + json response; + namespace fs = std::filesystem; + + try { + const json body = json::parse(req.body); + const std::string old_filename_str = body.at("old_filename"); + const std::string new_filename_str = body.at("new_filename"); + + if (old_filename_str.find("..") != std::string::npos || old_filename_str.find_first_of("/\\") != std::string::npos || + new_filename_str.find("..") != std::string::npos || new_filename_str.find_first_of("/\\") != std::string::npos) { + res.status = 400; + response = {{"error", "Invalid filename format."}}; + res.set_content(response.dump(), "application/json; charset=utf-8"); + return; + } + + const fs::path old_path = fs::path(params.slot_save_path) / old_filename_str; + const fs::path new_path = fs::path(params.slot_save_path) / new_filename_str; + + if (!fs::exists(old_path) || !fs::is_regular_file(old_path)) { + res.status = 404; + response = {{"error", "Source file not found."}}; + res.set_content(response.dump(), "application/json; charset=utf-8"); + return; + } + + if (fs::exists(new_path)) { + res.status = 409; + response = {{"error", "Destination filename already exists."}}; + res.set_content(response.dump(), "application/json; charset=utf-8"); + return; + } + + std::error_code ec; + fs::rename(old_path, new_path, ec); + + if (ec) { + res.status = 500; + response = {{"error", "Failed to rename file: " + ec.message()}}; + } else { + response = { + {"status", "renamed"}, + {"old_filename", old_filename_str}, + {"new_filename", new_filename_str} + }; + } + + } catch (const json::parse_error& e) { + res.status = 400; + response = {{"error", "Invalid JSON request body."}}; + } catch (const json::out_of_range& e) { + res.status = 400; + response = {{"error", "Missing 'old_filename' or 'new_filename' in request body."}}; + } catch (const std::exception& e) { + res.status = 500; + response = {{"error", e.what()}}; + } + + res.set_content(response.dump(), "application/json; charset=utf-8"); + }; + + auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) { + return [content, len, mime_type](const httplib::Request &, httplib::Response & res) { + res.set_content(reinterpret_cast(content), len, mime_type); + return false; + }; + }; +#ifdef SQLITE3_MODERN_CPP_SUPPORT + const auto handle_version = [¶ms, sqlite_extension_loaded](const httplib::Request&, httplib::Response& res) { + res.set_content( + json{{"version", 4}, + {"features", {{"sql", !params.sql_save_file.empty()}, {"zstd_compression", sqlite_extension_loaded}}}}.dump(), + "application/json" + ); + }; +#else + const auto handle_version = [](const httplib::Request&, httplib::Response& res)-> void { + res.set_content( + json{{"version", 4}, + {"features", {{"sql", false}, {"zstd_compression", false}}}}.dump(), + "application/json" + ); + }; +#endif + +#ifdef SQLITE3_MODERN_CPP_SUPPORT + auto db_handler = [db_handle](auto func) { + return [func, db_handle](const httplib::Request& req, httplib::Response& res) { + res.set_header("Access-Control-Allow-Origin", "*"); + try { + const json body = !req.body.empty() ? json::parse(req.body) : json::object(); + func(*db_handle, body, req, res); + } catch(const std::exception& e) { + res.status = 500; + res.set_content( + json{{"ok", false}, {"message", e.what()}}.dump(), + "application/json" + ); + } + }; + }; +#else + auto db_handler = [db_handle](auto func) { + return [func, db_handle](const httplib::Request& req, httplib::Response& res) { + res.set_header("Access-Control-Allow-Origin", "*"); + res.status = 500; + res.set_content( + json{{"ok", false}, {"message", "Sqlite3 support was not enabled. Recompile with '-DLLAMA_SERVER_SQLITE3=ON'"}}.dump(), + "application/json" + ); + }; + }; +#endif + + const auto normalize_store_name = [](const std::string& storeName) { + if(storeName.empty()) return std::string("sessions"); + + std::string normalized; + normalized.reserve(storeName.size()); + + for(char c : storeName) { + if(std::isalpha(static_cast(c))) { + normalized.push_back(std::tolower(static_cast(c))); + } + } + + return normalized.empty() ? "sessions" : normalized; + }; + + const auto get_key_string = [](const json& j) { + return j.is_string() ? j.get() : j.dump(); + }; + + + const auto handle_load = db_handler([normalize_store_name, get_key_string](auto& db, const json& body, auto&, auto& res) { + std::string data; + const std::string store = normalize_store_name(body["storeName"]); + db.db << "SELECT data FROM " + store + " WHERE key = ?" << get_key_string(body["key"]) >> data; + if(data.empty()) { + res.status = 404; + res.set_content(json{{"ok", false}, {"message", "Key not found"}}.dump(), "application/json"); + } else { + json response{{"ok", true}}; + response["result"] = (store == "names") ? json(data) : json::parse(data); + res.set_content(response.dump(), "application/json"); + } + }); + + const auto handle_save = db_handler([normalize_store_name, get_key_string](auto& db, const json& body, auto&, auto& res) { + const std::string store = normalize_store_name(body["storeName"]); + const std::string data = (store == "names") ? body["data"].get() : body["data"].dump(); + db.db << "INSERT OR REPLACE INTO " + store + " (key, data) VALUES (?, ?)" << get_key_string(body["key"]) << data; + res.set_content(json{{"ok", true}, {"result", "Data saved successfully"}}.dump(), "application/json"); + }); + + const auto handle_rename = db_handler([get_key_string](auto& db, const json& body, auto&, auto& res) { + db.db << "UPDATE names SET data = ? WHERE key = ?" + << body["newName"].get() + << get_key_string(body["key"]); + res.set_content(json{{"ok", true}, {"result", "Session renamed successfully"}}.dump(), "application/json"); + }); + + const auto handle_all = db_handler([normalize_store_name](auto& db, const json& body, auto&, auto& res) { + json result = json::object(); + db.db << "SELECT key, data FROM " + normalize_store_name(body["storeName"]) >> + [&](const std::string& key, const std::string& data) { + result[key] = json::parse(data); + }; + res.set_content(json{{"ok", true}, {"result", result}}.dump(), "application/json"); + }); + + const auto handle_sessions = db_handler([](auto& db, const json& body, auto&, auto& res) { + json result = json::object(); + db.db << "SELECT key, data FROM names" >> [&](const std::string& key, const std::string& data) { + result[key] = data; + }; + res.set_content(json{{"ok", true}, {"result", result}}.dump(), "application/json"); + }); + + const auto handle_delete = db_handler([normalize_store_name, get_key_string](auto& db, const json& body, auto&, auto& res) { + db.db << "DELETE FROM " + normalize_store_name(body["storeName"]) + " WHERE key = ?" + << get_key_string(body["key"]); + res.set_content(json{{"ok", true}, {"result", "Session deleted successfully"}}.dump(), "application/json"); + }); + + const auto handle_vacuum = db_handler([](auto& db, const json& body, auto&, auto& res) { + json result = json::object(); + db.db << "VACUUM"; + res.set_content(json{"ok", true}.dump(), "application/json"); + }); + + const auto handle_zstd_get_configs = db_handler([](auto& db, const json& body, auto&, auto& res) { + json result = json::object(); + db.db << "SELECT id, config FROM _zstd_configs" >> [&](const std::string id, const std::string& config) { + result[id] = config; + }; + res.set_content(json{{"ok", true}, {"configs", result}}.dump(), "application/json"); + }); + + const auto handle_zstd_maintenance = db_handler([](auto& db, const json& body, auto&, auto& res) { + std::string data; + if (body["duration"].is_null()) { + db.db << "select zstd_incremental_maintenance(?, ?)" << nullptr << body["db_load"].get() >> data; + } + else { + db.db << "select zstd_incremental_maintenance(?, ?)" << body["duration"].get() << body["db_load"].get() >> data; + } + json response{{"ok", true}}; + response["result"] = json::parse(data); + res.set_content(response.dump(), "application/json"); + }); + + const auto handle_zstd_enable = db_handler([](auto& db, const json& body, auto&, auto& res) { + db.db << "select zstd_enable_transparent('{\"table\": \"" + body["table"].get() + "\",\"column\": \"" + body["column"].get() + "\", \"compression_level\": " + std::to_string(body["compression_level"].get()) + ", \"dict_chooser\": \"''a''\", \"train_dict_samples_ratio\": " + std::to_string(body["train_dict_samples_ratio"].get()) + "}')"; + res.set_content(json{"ok", true}.dump(), "application/json"); + }); + + const auto handle_zstd_config_update = db_handler([](auto& db, const json& body, auto&, auto& res) { + std::string patch_json = "{\"compression_level\": " + std::to_string(body["compression_level"].get()) + ", \"train_dict_samples_ratio\": " + std::to_string(body["train_dict_samples_ratio"].get()) + "}"; + db.db << "update _zstd_configs set config = json_patch(config, '" + patch_json + "')"; + res.set_content(json{{"ok", true}}.dump(), "application/json"); + }); + + // + // Router + // + if (params.webui == COMMON_WEBUI_NONE) { + LLAMA_LOG_INFO("Web UI is disabled\n"); + } + else { + // register static assets routes + if (!params.public_path.empty()) { + // Set the base directory for serving static files + svr->set_base_dir(params.public_path); + } + + { + // register static assets routes + if (!params.public_path.empty()) { + // Set the base directory for serving static files + bool is_found = svr->set_mount_point("/", params.public_path); + if (!is_found) { + GGML_ABORT("%s: static assets path not found: %s\n", __func__, params.public_path.c_str()); + return 1; + } + } + else { + + // using embedded static index.html + svr->Get("/", [params](const httplib::Request& req, httplib::Response& res) { + if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) { + res.set_content("Error: gzip is not supported by this browser", "text/plain"); + } + else { + res.set_header("Content-Encoding", "gzip"); + // COEP and COOP headers, required by pyodide (python interpreter) + res.set_header("Cross-Origin-Embedder-Policy", "require-corp"); + res.set_header("Cross-Origin-Opener-Policy", "same-origin"); + if (params.webui == COMMON_WEBUI_AUTO) { + res.set_content(reinterpret_cast(index_html_gz), index_html_gz_len, "text/html; charset=utf-8"); + } + else if (params.webui == COMMON_WEBUI_LLAMACPP) { + res.set_content(reinterpret_cast(index_llamacpp_html_gz), index_llamacpp_html_gz_len, "text/html; charset=utf-8"); + } + else { + res.set_content(reinterpret_cast(index_html_gz), index_html_gz_len, "text/html; charset=utf-8"); + } + } + return false; + }); + } + } + } + // register API routes + svr->Get ("/health", handle_health); + svr->Get ("/metrics", handle_metrics); + svr->Get ("/props", handle_props); + svr->Get("/v1/props", handle_props_simple); + svr->Get ("/v1/models", handle_models); + svr->Post("/completion", handle_completions); // legacy + svr->Post("/completions", handle_completions); // legacy + svr->Post("/v1/completions", handle_completions_oai); + svr->Post("/chat/completions", handle_chat_completions); + svr->Post("/v1/chat/completions", handle_chat_completions); + svr->Post("/v1/messages", handle_anthropic_messages); + svr->Post("/v1/messages/count_tokens", handle_anthropic_count_tokens); + svr->Post("/infill", handle_infill); + svr->Post("/embedding", handle_embeddings); // legacy + svr->Post("/embeddings", handle_embeddings); + svr->Post("/v1/embeddings", handle_embeddings_oai); + svr->Post("/tokenize", handle_tokenize); + svr->Post("/detokenize", handle_detokenize); + svr->Post("/apply-template", handle_apply_template); + // LoRA adapters hotswap + svr->Get ("/lora-adapters", handle_lora_adapters_list); + svr->Post("/lora-adapters", handle_lora_adapters_apply); + // Save & load slots + svr->Get ("/slots", handle_slots); + svr->Get ("/slots/list", list_slot_prompts); + if (!params.slot_save_path.empty()) { + // these endpoints rely on slot_save_path existing + svr->Post("/slots/:id_slot", handle_slots_action); + svr->Get ("/list", list_saved_prompts); + svr->Post("/delete_prompt", delete_saved_prompt); + svr->Post("/rename_prompt", rename_saved_prompt); + + } + + svr->Get ("/version", handle_version); + if (!params.sql_save_file.empty()) { + // these endpoints rely on sql_save_file existing + svr->Post("/load", handle_load); + svr->Post("/save", handle_save); + svr->Post("/rename", handle_rename); + svr->Post("/all", handle_all); + svr->Post("/sessions", handle_sessions); + svr->Get ("/sessions", handle_sessions); + svr->Post("/delete", handle_delete); + //VACUUM is there for the extension but does not require the extension + svr->Get ("/vacuum", handle_vacuum); +#ifdef SQLITE3_MODERN_CPP_SUPPORT + if (sqlite_extension_loaded) { + svr->Get ("/zstd_get_configs", handle_zstd_get_configs); + svr->Post("/zstd_incremental_maintenance", handle_zstd_maintenance); + svr->Post("/zstd_enable_transparent", handle_zstd_enable); + svr->Post("/zstd_update_transparent", handle_zstd_config_update); + } +#endif + } + // + // Start the server + // + if (params.n_threads_http < 1) { + // +2 threads for monitoring endpoints + params.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1); + } + log_data["n_threads_http"] = std::to_string(params.n_threads_http); + svr->new_task_queue = [¶ms] { return new httplib::ThreadPool(params.n_threads_http); }; + + LOG_INFO("HTTP server listening", log_data); + + // run the HTTP server in a thread - see comment below + std::thread t([&]() { + if (!svr->listen_after_bind()) { + state.store(SERVER_STATE_ERROR); + return 1; + } + + return 0; + }); + + ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) { + ctx_server.process_single_task(std::move(task)); + }); + ctx_server.queue_tasks.on_finish_multitask(std::bind( + &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1)); + ctx_server.queue_tasks.on_update_slots(std::bind( + &server_context::update_slots, &ctx_server)); + ctx_server.queue_results.on_multitask_update(std::bind( + &server_queue::update_multitask, + &ctx_server.queue_tasks, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3 + )); + + shutdown_handler = [&](int) { + ctx_server.queue_tasks.terminate(); + }; + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) + struct sigaction sigint_action; + sigint_action.sa_handler = signal_handler; + sigemptyset (&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); + sigaction(SIGTERM, &sigint_action, NULL); +#elif defined (_WIN32) + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { + return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; + }; + SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); +#endif + + ctx_server.queue_tasks.start_loop(); + + svr->stop(); + t.join(); + + llama_backend_free(); + + return 0; +} From f22318ea39a9101d5accf27b200aa95449cc764e Mon Sep 17 00:00:00 2001 From: SneedwareInc <254158255+SneedwareInc@users.noreply.github.com> Date: Tue, 27 Jan 2026 21:22:49 +0100 Subject: [PATCH 2/8] CRLF quickfix --- examples/server/server.cpp | 5532 ++++++++++++++++++------------------ 1 file changed, 2766 insertions(+), 2766 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 81a8006f1..6e7657fe8 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1,2766 +1,2766 @@ -#pragma warning(disable : 4996) -#include "server-context.h" -#include "server-common.h" -#include "chat.h" - -#include "common.h" -#include "speculative.h" -#include "mtmd.h" -#include "sampling.h" -#include "llama.h" -#include "llama-vocab.h" -#include -#include -#include - -// mime type for sending response -#define MIMETYPE_JSON "application/json; charset=utf-8" - - -#ifndef NDEBUG -// crash the server in debug mode, otherwise send an http 500 error -#define CPPHTTPLIB_NO_EXCEPTIONS 1 -#endif - -#include -#include "index.html.gz.hpp" -#include "index_llamacpp.html.gz.hpp" -#include "loading.html.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#ifdef SQLITE3_MODERN_CPP_SUPPORT -#include - -struct DatabaseHandle { - sqlite::database db; - - DatabaseHandle(const std::string& path) : db(path) { - db << "CREATE TABLE IF NOT EXISTS sessions (key TEXT PRIMARY KEY, data TEXT)"; - db << "CREATE TABLE IF NOT EXISTS templates (key TEXT PRIMARY KEY, data TEXT)"; - db << "CREATE TABLE IF NOT EXISTS names (key TEXT PRIMARY KEY, data TEXT)"; - } -}; -#endif - -using json = nlohmann::ordered_json; -namespace fs = std::filesystem; -constexpr int HTTP_POLLING_SECONDS = 1; - -bool server_verbose = false; -bool server_log_json = true; - - -enum server_state { - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded - SERVER_STATE_ERROR // An error occurred, load_model failed -}; - - -static inline std::string stop_type_to_str(stop_type type) { - switch (type) { - case STOP_TYPE_EOS: return "eos"; - case STOP_TYPE_WORD: return "word"; - case STOP_TYPE_LIMIT: return "limit"; - default: return "none"; - } -} - - -inline std::string get_model_name(std::string path) -{ - std::string filename = path.substr(path.find_last_of("/\\") + 1); - return filename; -}; - - -static json format_final_response_oaicompat(const json& request, json result, const std::string& completion_id, bool streaming = false) { - bool stopped_word = result.count("stopped_word") != 0; - bool stopped_eos = json_value(result, "stopped_eos", false); - int num_tokens_predicted = json_value(result, "tokens_predicted", 0); - int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); - std::string content = json_value(result, "content", std::string("")); - - std::string finish_reason = "length"; - if (stopped_word || stopped_eos) { - finish_reason = "stop"; - } - - json choices = - streaming ? json::array({ json{{"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}} }) - : json::array({ json{{"finish_reason", finish_reason}, - {"index", 0}, - {"message", json{{"content", content}, - {"role", "assistant"}}}} }); - - std::time_t t = std::time(0); - - json res = json{ - {"choices", choices}, - {"created", t}, - {"model", - json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", streaming ? "chat.completion.chunk" : "chat.completion"}, - {"usage", json { - {"completion_tokens", num_tokens_predicted}, - {"prompt_tokens", num_prompt_tokens}, - {"total_tokens", num_tokens_predicted + num_prompt_tokens} - }}, - {"id", completion_id} - }; - - if (server_verbose) { - res["__verbose"] = result; - } - - if (result.contains("completion_probabilities")) { - res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array()); - } - - return res; -} - -// return value is vector as there is one case where we might need to generate two responses -static std::vector format_partial_response_oaicompat(server_task_result task_result, const std::string& completion_id) { - json result = task_result.data; - std::cout << result.dump(4) << std::endl; - if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) { - return std::vector({ result }); - } - - bool first = json_value(result, "oaicompat_token_ctr", 0) == 0; - std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); - - bool stopped_word = json_value(result, "stopped_word", false); - bool stopped_eos = json_value(result, "stopped_eos", false); - bool stopped_limit = json_value(result, "stopped_limit", false); - std::string content = json_value(result, "content", std::string("")); - - std::string finish_reason; - if (stopped_word || stopped_eos) { - finish_reason = "stop"; - } - if (stopped_limit) { - finish_reason = "length"; - } - - std::time_t t = std::time(0); - - json choices; - - if (!finish_reason.empty()) { - choices = json::array({ json{{"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}} }); - } - else { - if (first) { - if (content.empty()) { - choices = json::array({ json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{{"role", "assistant"}}}} }); - } - else { - // We have to send this as two updates to conform to openai behavior - json initial_ret = json{ {"choices", json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"role", "assistant"} - }}}})}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"} }; - - json second_ret = json{ - {"choices", json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"content", content}}} - }})}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"} }; - - return std::vector({ initial_ret, second_ret }); - } - } - else { - // Some idiosyncrasy in task processing logic makes several trailing calls - // with empty content, we ignore these at the calee site. - if (content.empty()) { - return std::vector({ json::object() }); - } - - choices = json::array({ json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", - json{ - {"content", content}, - }}, - } }); - } - } - - json ret = json{ - {"choices", choices}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"} - }; - - if (task_result.timings.prompt_n != -1) { - ret.push_back({ "timings", task_result.timings.to_json() }); - } - - // - if (!finish_reason.empty()) { - int num_tokens_predicted = json_value(result, "tokens_predicted", 0); - int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); - ret.push_back({ "usage", json { - {"completion_tokens", num_tokens_predicted}, - {"prompt_tokens", num_prompt_tokens}, - {"total_tokens", num_tokens_predicted + num_prompt_tokens} - } }); - } - - return std::vector({ ret }); -} - - -static json format_embeddings_response_oaicompat(const json& request, const json& embeddings, bool use_base64 = false) { - json data = json::array(); - int32_t n_tokens = 0; - int i = 0; - for (const auto& elem : embeddings) { - json embedding_obj; - - if (use_base64) { - const auto& vec = json_value(elem, "embedding", json::array()).get>(); - const char* data_ptr = reinterpret_cast(vec.data()); - size_t data_size = vec.size() * sizeof(float); - embedding_obj = { - {"embedding", base64::encode(data_ptr, data_size)}, - {"index", i++}, - {"object", "embedding"}, - {"encoding_format", "base64"} - }; - } - else { - embedding_obj = { - {"embedding", json_value(elem, "embedding", json::array())}, - {"index", i++}, - {"object", "embedding"} - }; - } - data.push_back(embedding_obj); - n_tokens += json_value(elem, "tokens_evaluated", 0); - } - json res = json{ - {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", "list"}, - {"usage", json { - {"prompt_tokens", n_tokens}, - {"total_tokens", n_tokens} - }}, - {"data", data} - }; - - return res; -} - -static void log_server_request(const httplib::Request & req, const httplib::Response & res) { - // skip GH copilot requests when using default port - if (req.path == "/v1/health" || req.path == "/v1/completions") { - return; - } - - LOG_INFO("request", { - {"remote_addr", req.remote_addr}, - {"remote_port", req.remote_port}, - {"status", res.status}, - {"method", req.method}, - {"path", req.path}, - {"params", req.params}, - }); - - LOG_VERBOSE("request", { - {"request", req.body}, - {"response", res.body}, - }); -} - -// generator-like API for server responses, support pooling connection state and aggregating results -struct server_response_reader { - std::unordered_set id_tasks; - server_context& ctx_server; - size_t received_count = 0; - bool cancelled = false; - - server_response_reader(server_context& ctx_server) : ctx_server(ctx_server) {} - ~server_response_reader() { - stop(); - } - - void post_tasks(std::vector&& tasks) { - id_tasks = server_task::get_list_id(tasks); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(std::move(tasks)); - } - - bool has_next() { - return !cancelled && received_count < id_tasks.size(); - } - - // return nullptr if should_stop() is true before receiving a result - // note: if one error is received, it will stop further processing and return error result - server_task_result_ptr next(const std::function& should_stop) { - while (true) { - server_task_result_ptr result = ctx_server.queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); - if (result == nullptr) { - // timeout, check stop condition - if (should_stop()) { - SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n"); - return nullptr; - } - } - else { - if (result->is_error()) { - stop(); // cancel remaining tasks - SRV_DBG("%s", "received error result, stopping further processing\n"); - return result; - } - if (result->is_stop()) { - received_count++; - } - return result; - } - } - - // should not reach here - } - - struct batch_response { - bool is_terminated = false; // if true, indicates that processing was stopped before all results were received - std::vector results; - server_task_result_ptr error; // nullptr if no error - }; - - batch_response wait_for_all(const std::function& should_stop) { - batch_response batch_res; - batch_res.results.resize(id_tasks.size()); - while (has_next()) { - auto res = next(should_stop); - if (res == nullptr) { - batch_res.is_terminated = true; - return batch_res; - } - if (res->error) { - batch_res.error = std::move(res); - return batch_res; - } - const size_t idx = res->get_index(); - GGML_ASSERT(idx < batch_res.results.size() && "index out of range"); - GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received"); - batch_res.results[idx] = std::move(res); - } - return batch_res; - } - - void stop() { - ctx_server.queue_results.remove_waiting_task_ids(id_tasks); - if (has_next() && !cancelled) { - // if tasks is not finished yet, cancel them - cancelled = true; - std::vector cancel_tasks; - cancel_tasks.reserve(id_tasks.size()); - for (const auto& id_task : id_tasks) { - SRV_WRN("cancel task, id_task = %d\n", id_task); - server_task task(SERVER_TASK_TYPE_CANCEL); - task.id_target = id_task; - ctx_server.queue_results.remove_waiting_task_id(id_task); - cancel_tasks.push_back(std::move(task)); - } - // push to beginning of the queue, so it has highest priority - ctx_server.queue_tasks.post(std::move(cancel_tasks), true); - } - else { - SRV_DBG("%s", "all tasks already finished, no need to cancel\n"); - } - } -}; - -auto res_err = [](httplib::Response& res, json error_data) { - json final_response{ {"error", error_data} }; - res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON); - res.status = json_value(error_data, "code", 500); -}; - -auto res_ok = [](httplib::Response& res, const json& data) { - res.set_content(data.dump(), "application/json; charset=utf-8"); - res.status = 200; -}; - -std::function shutdown_handler; -std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; - -inline void signal_handler(int signal) { - if (is_terminating.test_and_set()) { - // in case it hangs, we can force terminate the server by hitting Ctrl+C twice - // this is for better developer experience, we can remove when the server is stable enough - fprintf(stderr, "Received second interrupt, terminating immediately.\n"); - exit(1); - } - - shutdown_handler(signal); -} - -int main(int argc, char ** argv) { -#if SERVER_VERBOSE != 1 - log_disable(); -#endif - // own arguments required by this example - gpt_params params; - - if (!gpt_params_parse(argc, argv, params)) { - gpt_params_print_usage(argc, argv, params); - return 1; - } - - // parse arguments from environment variables - gpt_params_parse_from_env(params); - - // TODO: not great to use extern vars - server_log_json = params.log_json; - server_verbose = params.verbosity > 0; - - - // struct that contains llama context and inference - server_context ctx_server; - - if (!params.system_prompt.empty()) { - ctx_server.system_prompt_set(params.system_prompt); - } - - if (params.model_alias == "unknown") { - params.model_alias = params.model; - } - - llama_backend_init(); - llama_numa_init(params.numa); - - LOG_INFO("build info", { - {"build", LLAMA_BUILD_NUMBER}, - {"commit", LLAMA_COMMIT} - }); - - LOG_INFO("system info", { - {"n_threads", params.n_threads}, - {"n_threads_batch", params.n_threads_batch}, - {"total_threads", std::thread::hardware_concurrency()}, - {"system_info", llama_print_system_info()}, - }); - - std::unique_ptr svr; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (params.ssl_file_key != "" && params.ssl_file_cert != "") { - LOG_INFO("Running with SSL", {{"key", params.ssl_file_key}, {"cert", params.ssl_file_cert}}); - svr.reset( - new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str()) - ); - } else { - LOG_INFO("Running without SSL", {}); - svr.reset(new httplib::Server()); - } -#else - svr.reset(new httplib::Server()); -#endif - - std::atomic state{SERVER_STATE_LOADING_MODEL}; - - svr->set_default_headers({{"Server", "ik_llama.cpp"}}); - - svr->set_logger(log_server_request); - - - - svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { - std::string message; - try { - std::rethrow_exception(std::move(ep)); - } catch (std::exception & e) { - message = e.what(); - } catch (...) { - message = "Unknown Exception"; - } - - json formatted_error = format_error_response(message, ERROR_TYPE_SERVER); - LOG_VERBOSE("Got exception", formatted_error); - res_err(res, formatted_error); - }); - - svr->set_error_handler([](const httplib::Request &, httplib::Response & res) { - if (res.status == 404) { - res_err(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND)); - } - // for other error codes, we skip processing here because it's already done by res_err() - }); - - // set timeouts and change hostname and port - svr->set_read_timeout (params.timeout_read); - svr->set_write_timeout(params.timeout_write); - - if (!svr->bind_to_port(params.hostname, params.port)) { - fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", params.hostname.c_str(), params.port); - return 1; - } - - std::unordered_map log_data; - - log_data["hostname"] = params.hostname; - log_data["port"] = std::to_string(params.port); - - if (params.api_keys.size() == 1) { - auto key = params.api_keys[0]; - log_data["api_key"] = "api_key: ****" + key.substr(std::max((int)(key.length() - 4), 0)); - } else if (params.api_keys.size() > 1) { - log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded"; - } - - // Necessary similarity of prompt for slot selection - ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; - ctx_server.cache_ram_n_min = params.cache_ram_n_min; - ctx_server.cache_ram_similarity = params.cache_ram_similarity; -#ifdef SQLITE3_MODERN_CPP_SUPPORT - auto db_handle = std::make_shared(params.sql_save_file); - bool sqlite_extension_loaded = false; - if (!params.sqlite_zstd_ext_file.empty()) { - auto* conn = db_handle->db.connection().get(); - sqlite3_enable_load_extension(conn, 1); - char* errmsg = nullptr; - const int rc = sqlite3_load_extension( - conn, - params.sqlite_zstd_ext_file.c_str(), - nullptr, - &errmsg - ); - if(rc != SQLITE_OK) { - const std::string err = errmsg ? errmsg : "Unknown extension error"; - sqlite3_free(errmsg); - LOG_WARNING("Failed to load extension", {{"err", err}}); - } - else { - sqlite_extension_loaded = true; - } - sqlite3_enable_load_extension(conn, 0); - } -#else - auto db_handle = false; -#endif - // load the model - if (!ctx_server.load_model(params)) { - state.store(SERVER_STATE_ERROR); - return 1; - } else { - ctx_server.init(); - state.store(SERVER_STATE_READY); - } - - LOG_INFO("model loaded", {}); - - const auto model_meta = ctx_server.model_meta(); - - // print sample chat example to make it clear which template is used - - LOG_INFO("chat template", { - {"chat_template", common_chat_templates_source(ctx_server.chat_templates.get())}, - }); - - LOG_INFO("chat template", { - {"chat_example", common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja, {}).c_str() - }, - {"built_in", params.chat_template.empty()}, - }); - // - // Middlewares - // - - auto middleware_validate_api_key = [¶ms](const httplib::Request & req, httplib::Response & res) { - static const std::unordered_set public_endpoints = { - "/health", - "/v1/health", - "/models", - "/v1/models", - "/api/tags" - }; - - // If API key is not set, skip validation - if (params.api_keys.empty()) { - return true; - } - - // If path is public or is static file, skip validation - if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") { - return true; - } - - // Check for API key in the header - auto auth_header = req.get_header_value("Authorization"); - - std::string prefix = "Bearer "; - if (auth_header.substr(0, prefix.size()) == prefix) { - std::string received_api_key = auth_header.substr(prefix.size()); - if (std::find(params.api_keys.begin(), params.api_keys.end(), received_api_key) != params.api_keys.end()) { - return true; // API key is valid - } - } - - auth_header = req.get_header_value("X-Api-Key"); - - if (std::find(params.api_keys.begin(), params.api_keys.end(), auth_header) != params.api_keys.end()) { - return true; // API key is valid - } - - // API key is invalid or not provided - res.status = 401; - res.set_content( - (json { - {"error", { - {"message", "Invalid API Key"}, - {"type", "authentication_error"}, - {"code", 401} - }} - }).dump(-1, ' ', false, json::error_handler_t::replace), - "application/json; charset=utf-8" - ); - LOG_WARNING("Unauthorized: Invalid API Key\n", {}); - return false; - }; - - auto middleware_server_state = [&state](const httplib::Request& req, httplib::Response& res) { - server_state current_state = state.load(); - if (current_state == SERVER_STATE_LOADING_MODEL) { - auto tmp = string_split(req.path, '.'); - if (req.path == "/" || tmp.back() == "html") { - res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); - res.status = 503; - } - else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") { - // allow the models endpoint to be accessed during loading - return true; - } - else { - res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); - } - return false; - } - return true; - }; - - // register server middlewares - svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request& req, httplib::Response& res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - // If this is OPTIONS request, skip validation because browsers don't include Authorization header - if (req.method == "OPTIONS") { - res.set_header("Access-Control-Allow-Credentials", "true"); - res.set_header("Access-Control-Allow-Methods", "GET, POST"); - res.set_header("Access-Control-Allow-Headers", "*"); - res.set_content("", "text/html"); // blank response, no data - return httplib::Server::HandlerResponse::Handled; // skip further processing - } - if (!middleware_server_state(req, res)) { - return httplib::Server::HandlerResponse::Handled; - } - if (!middleware_validate_api_key(req, res)) { - return httplib::Server::HandlerResponse::Handled; - } - return httplib::Server::HandlerResponse::Unhandled; - }); - - // - // Route handlers (or controllers) - // - - const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) { - server_state current_state = state.load(); - switch (current_state) { - case SERVER_STATE_READY: - { - // request slots data using task queue - server_task task; - task.id = ctx_server.queue_tasks.get_new_id(); - task.type = SERVER_TASK_TYPE_METRICS; - task.id_target = -1; - - ctx_server.queue_results.add_waiting_task_id(task.id); - ctx_server.queue_tasks.post(std::move(task)); - - // get the result - server_task_result result = ctx_server.queue_results.recv(task.id); - ctx_server.queue_results.remove_waiting_task_id(task.id); - - const int n_idle_slots = result.data.at("idle"); - const int n_processing_slots = result.data.at("processing"); - - json health = { - {"status", "ok"}, - {"slots_idle", n_idle_slots}, - {"slots_processing", n_processing_slots} - }; - - res.status = 200; // HTTP OK - if (params.endpoint_slots && req.has_param("include_slots")) { - health["slots"] = result.data.at("slots"); - } - - if (n_idle_slots == 0) { - health["status"] = "no slot available"; - if (req.has_param("fail_on_no_slot")) { - res.status = 503; // HTTP Service Unavailable - } - } - - res.set_content(health.dump(), "application/json"); - break; - } - case SERVER_STATE_LOADING_MODEL: - { - res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); - } break; - case SERVER_STATE_ERROR: - { - res_err(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER)); - } break; - } - }; - - const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) { - if (!params.endpoint_slots) { - res_err(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED)); - return; - } - - // request slots data using task queue - server_task task; - task.id = ctx_server.queue_tasks.get_new_id(); - task.id_multi = -1; - task.id_target = -1; - task.type = SERVER_TASK_TYPE_METRICS; - - ctx_server.queue_results.add_waiting_task_id(task.id); - ctx_server.queue_tasks.post(std::move(task)); - - // get the result - server_task_result result = ctx_server.queue_results.recv(task.id); - ctx_server.queue_results.remove_waiting_task_id(task.id); - - res.set_content(result.data.at("slots").dump(), "application/json"); - res.status = 200; // HTTP OK - }; - - const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { - if (!params.endpoint_metrics) { - res_err(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED)); - return; - } - - // request slots data using task queue - server_task task; - task.id = ctx_server.queue_tasks.get_new_id(); - task.id_multi = -1; - task.id_target = -1; - task.type = SERVER_TASK_TYPE_METRICS; - task.data.push_back({{"reset_bucket", true}}); - - ctx_server.queue_results.add_waiting_task_id(task.id); - ctx_server.queue_tasks.post(std::move(task)); - - // get the result - server_task_result result = ctx_server.queue_results.recv(task.id); - ctx_server.queue_results.remove_waiting_task_id(task.id); - - json data = result.data; - - const uint64_t n_prompt_tokens_processed = data.at("n_prompt_tokens_processed"); - const uint64_t t_prompt_processing = data.at("t_prompt_processing"); - - const uint64_t n_tokens_predicted = data.at("n_tokens_predicted"); - const uint64_t t_tokens_generation = data.at("t_tokens_generation"); - - const int32_t kv_cache_used_cells = data.at("kv_cache_used_cells"); - - // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names - json all_metrics_def = json { - {"counter", {{ - {"name", "prompt_tokens_total"}, - {"help", "Number of prompt tokens processed."}, - {"value", (uint64_t) data.at("n_prompt_tokens_processed_total")} - }, { - {"name", "prompt_seconds_total"}, - {"help", "Prompt process time"}, - {"value", (uint64_t) data.at("t_prompt_processing_total") / 1.e3} - }, { - {"name", "tokens_predicted_total"}, - {"help", "Number of generation tokens processed."}, - {"value", (uint64_t) data.at("n_tokens_predicted_total")} - }, { - {"name", "tokens_predicted_seconds_total"}, - {"help", "Predict process time"}, - {"value", (uint64_t) data.at("t_tokens_generation_total") / 1.e3} - }}}, - {"gauge", {{ - {"name", "prompt_tokens_seconds"}, - {"help", "Average prompt throughput in tokens/s."}, - {"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.} - },{ - {"name", "predicted_tokens_seconds"}, - {"help", "Average generation throughput in tokens/s."}, - {"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.} - },{ - {"name", "kv_cache_usage_ratio"}, - {"help", "KV-cache usage. 1 means 100 percent usage."}, - {"value", 1. * kv_cache_used_cells / params.n_ctx} - },{ - {"name", "kv_cache_tokens"}, - {"help", "KV-cache tokens."}, - {"value", (uint64_t) data.at("kv_cache_tokens_count")} - },{ - {"name", "requests_processing"}, - {"help", "Number of request processing."}, - {"value", (uint64_t) data.at("processing")} - },{ - {"name", "requests_deferred"}, - {"help", "Number of request deferred."}, - {"value", (uint64_t) data.at("deferred")} - }}} - }; - - std::stringstream prometheus; - - for (const auto & el : all_metrics_def.items()) { - const auto & type = el.key(); - const auto & metrics_def = el.value(); - - for (const auto & metric_def : metrics_def) { - const std::string name = metric_def.at("name"); - const std::string help = metric_def.at("help"); - - auto value = json_value(metric_def, "value", 0.); - prometheus << "# HELP llamacpp:" << name << " " << help << "\n" - << "# TYPE llamacpp:" << name << " " << type << "\n" - << "llamacpp:" << name << " " << value << "\n"; - } - } - - const int64_t t_start = data.at("t_start"); - res.set_header("Process-Start-Time-Unix", std::to_string(t_start)); - - res.set_content(prometheus.str(), "text/plain; version=0.0.4"); - res.status = 200; // HTTP OK - }; - - const auto handle_slots_save = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { - json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return; - } - std::string filepath = params.slot_save_path + filename; - - server_task task; - task.type = SERVER_TASK_TYPE_SLOT_SAVE; - task.data = { - { "id_slot", id_slot }, - { "filename", filename }, - { "filepath", filepath } - }; - - const int id_task = ctx_server.queue_tasks.post(std::move(task)); - ctx_server.queue_results.add_waiting_task_id(id_task); - - server_task_result result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); - - if (result.error) { - res_err(res, result.data); - } else { - res.set_content(result.data.dump(), "application/json"); - } - }; - - const auto handle_slots_restore = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { - json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return; - } - std::string filepath = params.slot_save_path + filename; - - server_task task; - task.type = SERVER_TASK_TYPE_SLOT_RESTORE; - task.data = { - { "id_slot", id_slot }, - { "filename", filename }, - { "filepath", filepath } - }; - - const int id_task = ctx_server.queue_tasks.post(std::move(task)); - ctx_server.queue_results.add_waiting_task_id(id_task); - - server_task_result result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); - - if (result.error) { - res_err(res, result.data); - } else { - res.set_content(result.data.dump(), "application/json"); - } - }; - - const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { - server_task task; - task.type = SERVER_TASK_TYPE_SLOT_ERASE; - task.data = { - { "id_slot", id_slot }, - }; - - const int id_task = ctx_server.queue_tasks.post(std::move(task)); - ctx_server.queue_results.add_waiting_task_id(id_task); - - server_task_result result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); - - if (result.error) { - res_err(res, result.data); - } else { - res.set_content(result.data.dump(), "application/json"); - } - }; - - const auto handle_slots_action = [&handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { - std::string id_slot_str = req.path_params.at("id_slot"); - int id_slot; - - try { - id_slot = std::stoi(id_slot_str); - } catch (const std::exception &) { - res_err(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); - return; - } - - std::string action = req.get_param_value("action"); - - if (action == "save") { - handle_slots_save(req, res, id_slot); - } else if (action == "restore") { - handle_slots_restore(req, res, id_slot); - } else if (action == "erase") { - handle_slots_erase(req, res, id_slot); - } else { - res_err(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); - } - }; - - const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) { - std::string template_key = "tokenizer.chat_template", curr_tmpl; - int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0); - if (tlen > 0) { - std::vector curr_tmpl_buf(tlen + 1, 0); - if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) { - curr_tmpl = std::string(curr_tmpl_buf.data(), tlen); - } - } - json data = { - { "system_prompt", ctx_server.system_prompt.c_str() }, - { "model_alias", ctx_server.params_base.model_alias }, - { "model_path", ctx_server.params_base.model}, - { "default_generation_settings", ctx_server.default_generation_settings_for_props }, - { "total_slots", ctx_server.params_base.n_parallel }, - { "model_name", get_model_name(ctx_server.params_base.model)}, - { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, - { "bos_token", common_token_to_piece(ctx_server.ctx, llama_token_bos(ctx_server.model), /* special= */ true)}, - { "eos_token", common_token_to_piece(ctx_server.ctx, llama_token_eos(ctx_server.model), /* special= */ true)}, - { "model_path", ctx_server.params_base.model }, - { "modalities", json { - {"vision", ctx_server.oai_parser_opt.allow_image}, - {"audio", ctx_server.oai_parser_opt.allow_audio}, - } }, - { "n_ctx", ctx_server.n_ctx } - - }; - - if (ctx_server.params_base.use_jinja) { - if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) { - data["chat_template_tool_use"] = tool_use_src; - } - } - res.set_content(data.dump(), "application/json; charset=utf-8"); - }; - - const auto handle_props_simple = [&ctx_server](const httplib::Request& req, httplib::Response& res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - int n_past = 0; - int slot_id = 0; - for (server_slot& slot : ctx_server.slots) { - if (slot.n_past > n_past) { - n_past = slot.n_past; - slot_id = slot.id; - } - } - json data = { - { "model_name", get_model_name(ctx_server.params_base.model)}, - { "model_path", ctx_server.params_base.model }, - { "modalities", json { - {"vision", ctx_server.oai_parser_opt.allow_image}, - {"audio", ctx_server.oai_parser_opt.allow_audio}, - } }, - { "n_ctx", ctx_server.n_ctx } - }; - res.set_content(data.dump(), "application/json; charset=utf-8"); - }; - - - - -// handle completion-like requests (completion, chat, infill) - // we can optionally provide a custom format for partial results and final results - const auto handle_completions_impl = [&ctx_server, ¶ms]( - server_task_type type, - json& data, - const std::vector& files, - const std::function& is_connection_closed, - httplib::Response& res, - oaicompat_type oaicompat) -> void { - GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); - - // ---------------------------------------------------------------- - // 1. Regex Validation - // ---------------------------------------------------------------- - auto validate_regex_list = [&](const std::string& field_name) -> std::string { - if (data.contains(field_name) && data[field_name].is_array()) { - for (const auto& val : data[field_name]) { - if (val.is_string()) { - std::string s = val.get(); - if (!s.empty()) { - try { - std::regex re(s); - } catch (const std::regex_error& e) { - return s; - } - } - } - } - } - return ""; - }; - - std::string invalid_re = validate_regex_list("banned_regex"); - if (invalid_re.empty()) invalid_re = validate_regex_list("banned_regex_case_insensitive"); - - if (!invalid_re.empty()) { - res_err(res, format_error_response("Invalid regex: " + invalid_re, ERROR_TYPE_INVALID_REQUEST)); - return; - } - - const auto completion_id = gen_chatcmplid(); - - // Process prompt / inputs - std::vector inputs; - try { - const auto& prompt = data.at("prompt"); - if (oaicompat && ctx_server.mctx != nullptr) { - inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); - } - else { - inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true); - } - } - catch (const std::exception& e) { - res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); - return; - } - - // ---------------------------------------------------------------- - // Check if we need the complex "Banned String" logic - // Only enable if the lists are present AND contain actual strings. - // ---------------------------------------------------------------- - auto list_has_content = [&](const std::string& key) { - if (data.contains(key) && data[key].is_array()) { - for (const auto& item : data[key]) { - if (item.is_string() && !item.get().empty()) { - return true; - } - } - } - return false; - }; - - bool has_banned_content = list_has_content("banned_strings") || - list_has_content("banned_regex") || - list_has_content("banned_regex_case_insensitive"); - - if (!has_banned_content) { - // ---------------------------------------------------------------- - // PATH A: Standard Logic (server_response_reader) - // ---------------------------------------------------------------- - - // need to store the reader as a pointer, so that it won't be destroyed when the handle returns - // use shared_ptr as it's shared between the chunked_content_provider() and on_complete() - const auto rd = std::make_shared(ctx_server); - - try { - std::vector tasks; - tasks.reserve(inputs.size()); - for (size_t i = 0; i < inputs.size(); i++) { - server_task task = server_task(type); - - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - - task.tokens = std::move(inputs[i]); - task.data = data; - task.id_slot = json_value(data, "id_slot", -1); - - // OAI-compat - task.params.oaicompat = oaicompat; - task.params.oaicompat_cmpl_id = completion_id; - task.params.oaicompat_model = get_model_name(ctx_server.params_base.model); - tasks.push_back(std::move(task)); - } - - rd->post_tasks(std::move(tasks)); - } - catch (const std::exception& e) { - res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); - return; - } - bool stream = json_value(data, "stream", false); - if (!stream) { - // non-stream, wait for the results - auto all_results = rd->wait_for_all(is_connection_closed); - if (all_results.is_terminated) { - llama_decode_stop(); // send a signal to stop decode process - return; // connection is closed - } - else if (all_results.error) { - res_err(res, all_results.error->to_json()); - return; - } - else { - json arr = json::array(); - for (auto& res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - if (oaicompat) { - arr.push_back(format_final_response_oaicompat(data, res->data, completion_id, false)); - } else { - arr.push_back(res->to_json()); - } - } - // if single request, return single object instead of array - res_ok(res, arr.size() == 1 ? arr[0] : arr); - } - } - else { - // in streaming mode, the first error must be treated as non-stream response - // this is to match the OAI API behavior - // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 - server_task_result_ptr first_result = rd->next(is_connection_closed); - if (first_result == nullptr) { - llama_decode_stop(); // send a signal to stop decode process - return; // connection is closed - } - else if (first_result->is_error()) { - res_err(res, first_result->to_json()); - return; - } - else { - GGML_ASSERT( - dynamic_cast(first_result.get()) != nullptr - || dynamic_cast(first_result.get()) != nullptr - ); - } - - // Prepare first result JSON (handling OAI format if needed) - std::vector first_result_parts; - if (oaicompat) { - first_result_parts = format_partial_response_oaicompat(*first_result, completion_id); - } else { - first_result_parts.push_back(first_result->to_json()); - } - - const auto chunked_content_provider = [first_result_parts, rd, oaicompat, completion_id](size_t, httplib::DataSink& sink) mutable -> bool { - const auto sse = [oaicompat, &sink](const json& res) { - if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) { - return server_sent_anthropic_event(sink, res); - } - else { - return server_sent_event(sink, res); - } - }; - - // flush the first result parts - for (auto& part : first_result_parts) { - if (!part.empty()) { - if (!sse(part)) { - sink.done(); - return false; // sending failed, go to on_complete() - } - part.clear(); // mark as sent - } - } - - // receive subsequent results - auto result = rd->next([&sink] { return !sink.is_writable(); }); - if (result == nullptr) { - sink.done(); - return false; // connection is closed, go to on_complete() - } - - // send the results - bool ok = false; - if (result->is_error()) { - ok = sse(json{ { "error", result->to_json() } }); - sink.done(); - return false; // go to on_complete() - } - else { - GGML_ASSERT( - dynamic_cast(result.get()) != nullptr - || dynamic_cast(result.get()) != nullptr - ); - - if (oaicompat) { - std::vector parts = format_partial_response_oaicompat(*result, completion_id); - for (const auto& part : parts) { - ok = sse(part); - if (!ok) break; - } - } else { - ok = sse(result->to_json()); - } - } - - if (!ok) { - sink.done(); - return false; // sending failed, go to on_complete() - } - - // check if there is more data - if (!rd->has_next()) { - if (oaicompat != OAICOMPAT_TYPE_ANTHROPIC && oaicompat != OAICOMPAT_TYPE_NONE) { - static const std::string ev_done = "data: [DONE]\n\n"; - sink.write(ev_done.data(), ev_done.size()); - } - sink.done(); - return false; // no more data, go to on_complete() - } - - // has next data, continue - return true; - }; - - auto on_complete = [rd](bool) { - rd->stop(); - }; - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } - - } else { - // ---------------------------------------------------------------- - // PATH B: Banned Content Logic (Slow Path with Buffering & Rewind) - // ---------------------------------------------------------------- - auto buffer_and_check_string_ban_and_rewind_logic = [&]() { - // Helper to mimic request_cancel using the task queue directly - auto request_cancel = [&ctx_server](int id_target) { - server_task task(SERVER_TASK_TYPE_CANCEL); - task.id_target = id_target; - std::vector tasks; - tasks.push_back(std::move(task)); - ctx_server.queue_tasks.post(std::move(tasks), true); - }; - - // Helper to post a completion task with correct OAI params - auto post_task_with_params = [&ctx_server, oaicompat, completion_id](int id_task, json& task_data, server_tokens& tokens) { - server_task task(SERVER_TASK_TYPE_COMPLETION); - task.id = id_task; - task.index = 0; - task.tokens = std::move(tokens); - task.data = task_data; - task.id_slot = json_value(task_data, "id_slot", -1); - - // Critical: Set OAI params so worker generates correct output format - task.params.oaicompat = oaicompat; - task.params.oaicompat_cmpl_id = completion_id; - task.params.oaicompat_model = get_model_name(ctx_server.params_base.model); - - std::vector tasks; - tasks.push_back(std::move(task)); - ctx_server.queue_tasks.post(std::move(tasks)); - }; - - const int id_task = ctx_server.queue_tasks.get_new_id(); - ctx_server.queue_results.add_waiting_task_id(id_task); - - // Use helper instead of request_completion - post_task_with_params(id_task, data, inputs[0]); - - bool stream = json_value(data, "stream", false); - - if (!stream) { - // Non-streaming: wait for result (using pointer to avoid slicing) - std::unordered_set ids = { id_task }; - server_task_result_ptr result = nullptr; - - // Simple blocking wait - while (!result) { - result = ctx_server.queue_results.recv_with_timeout(ids, 1); - if (!result && is_connection_closed()) { - request_cancel(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); - return; - } - } - - if (!result->is_error()) { - json result_json; - if (oaicompat) { - result_json = format_final_response_oaicompat(data, result->data, completion_id, false); - } else { - result_json = result->to_json(); - } - res.set_content(result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); - } - else { - res_err(res, result->to_json()); - } - ctx_server.queue_results.remove_waiting_task_id(id_task); - } - else { - // Shared state to track the currently running task ID across retries. - auto active_task_id = std::make_shared(id_task); - - // Capture 'data' by value to use as a template for retries - const auto chunked_content_provider = [id_task, active_task_id, &ctx_server, completion_id, oaicompat, send_done = params.send_done, data, request_cancel, post_task_with_params](size_t, httplib::DataSink& sink) mutable { - // Define sse here so it's visible to both try and catch blocks - const auto sse = [oaicompat, &sink](const json &res) { - if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) { - return server_sent_anthropic_event(sink, res); - } else { - return server_sent_event(sink, res); - } - }; - - try { - bool successful_completion = false; - - // 1. Parse Configuration from Request - - // Banned Strings - std::vector stop_phrases; - if (data.contains("banned_strings") && data["banned_strings"].is_array()) { - for (const auto& val : data["banned_strings"]) { - if (val.is_string()) { - std::string s = val.get(); - if (!s.empty()) stop_phrases.push_back(s); - } - } - } - - // Sort banned strings by length (descending) - std::sort(stop_phrases.begin(), stop_phrases.end(), [](const std::string& a, const std::string& b) { - return a.length() > b.length(); - }); - - // Banned Regex (Case Sensitive & Insensitive) - std::vector regex_patterns; // For buffer size calculation - std::vector stop_regexes; // Compiled regexes - - auto add_regex_list = [&](const std::string& field_name, bool case_insensitive) { - if (data.contains(field_name) && data[field_name].is_array()) { - for (const auto& val : data[field_name]) { - if (val.is_string()) { - std::string s = val.get(); - if (!s.empty()) { - auto flags = std::regex_constants::ECMAScript; - if (case_insensitive) flags |= std::regex_constants::icase; - stop_regexes.emplace_back(s, flags); - regex_patterns.push_back(s); - } - } - } - } - }; - - // We assume validation passed in handle_completions_impl, so no try-catch needed here - add_regex_list("banned_regex", false); - add_regex_list("banned_regex_case_insensitive", true); - - // Logit Bias Penalty (Default: -10000.0) - float ban_bias = -10000.0f; - if (data.contains("banned_bias") && data["banned_bias"].is_number()) { - ban_bias = data["banned_bias"].get(); - } - - // Manual Buffer Size - size_t manual_buffer_size = 0; - if (data.contains("banbuffer_size") && data["banbuffer_size"].is_number_unsigned()) { - manual_buffer_size = data["banbuffer_size"].get(); - } - - // Token Limit Tracking - int original_n_predict = -1; - if (data.contains("n_predict") && data["n_predict"].is_number_integer()) { - original_n_predict = data["n_predict"].get(); - } - int total_tokens_streamed = 0; - - // ============================================================ - // FAST PATH: No banned strings AND No regex -> No buffering - // ============================================================ - if (stop_phrases.empty() && stop_regexes.empty()) { - while (true) { - std::unordered_set ids = { *active_task_id }; - server_task_result_ptr result = nullptr; - while (!result) { - result = ctx_server.queue_results.recv_with_timeout(ids, 1); - if (!result && !sink.is_writable()) { - request_cancel(*active_task_id); - ctx_server.queue_results.remove_waiting_task_id(*active_task_id); - return false; - } - } - - if (!result->is_error()) { - // Use format_partial_response_oaicompat to get the correct chunks - std::vector parts; - if (oaicompat) { - parts = format_partial_response_oaicompat(*result, completion_id); - } else { - parts.push_back(result->data); - } - - for (const auto& item : parts) { - if (!sse(item)) { - request_cancel(*active_task_id); - ctx_server.queue_results.remove_waiting_task_id(*active_task_id); - return false; - } - } - - if (result->is_stop()) { - successful_completion = true; - break; - } - } else { - sse(result->to_json()); - ctx_server.queue_results.remove_waiting_task_id(*active_task_id); - return false; - } - } - } - // ============================================================ - // SLOW PATH: Buffering and Banning Logic - // ============================================================ - else { - // Calculate Buffer Size - size_t BUFFER_SIZE; - if (manual_buffer_size > 0) { - BUFFER_SIZE = manual_buffer_size; - } else { - size_t max_len = 0; - // Check strings - if (!stop_phrases.empty()) { - max_len = stop_phrases[0].length(); // First is longest due to sort - } - // Check regex patterns - for (const auto& pat : regex_patterns) { - if (pat.length() > max_len) max_len = pat.length(); - } - - // Default: Longest string/regex + 1 - BUFFER_SIZE = std::max((size_t)1, max_len + 1); - } - - // Initialize Buffer & State - std::deque token_buffer; - - int current_task_id = id_task; - - // Track bans specifically for the current "next token" to be generated. - std::set current_step_bans; - int ban_slot_index = -1; - - // Track the text that has been confirmed/sent to the client. - std::string current_prompt_str = ""; - if (data.contains("prompt") && data["prompt"].is_string()) { - current_prompt_str = data["prompt"].get(); - } - - // Helper to extract text content - auto get_content_str = [](const json& j) -> std::string { - if (j.contains("choices") && j["choices"].is_array() && !j["choices"].empty()) { - const auto& choice = j["choices"][0]; - if (choice.contains("delta") && choice["delta"].contains("content")) { - auto val = choice["delta"]["content"]; - if (val.is_string()) return val.get(); - } - } - if (j.contains("content")) { - auto val = j["content"]; - if (val.is_string()) return val.get(); - } - return ""; - }; - - // Helper to extract Token ID - auto get_token_id = [](const json& j) -> int { - if (j.contains("__raw_token_id")) return j["__raw_token_id"].get(); - if (j.contains("token")) return j["token"].get(); - if (j.contains("id")) return j["id"].get(); - return -1; - }; - - // Helper for case-insensitive search - auto to_lower_str = [](std::string s) { - std::transform(s.begin(), s.end(), s.begin(), - [](unsigned char c){ return std::tolower(c); }); - return s; - }; - - // Helper to print buffer - auto print_debug_buffer = [&](const std::deque& buf) { - std::cout << "Debug TokenBuffer (Size " << BUFFER_SIZE << "): ["; - size_t print_len = std::max(buf.size(), BUFFER_SIZE); - for (size_t i = 0; i < print_len; ++i) { - if (i < buf.size()) { - std::string content = get_content_str(buf[i]); - std::string escaped; - for (char c : content) { - if (c == '\n') escaped += "\\n"; - else if (c == '"') escaped += "\\\""; - else escaped += c; - } - std::cout << "\"" << escaped << "\""; - } else { - std::cout << "\"\""; - } - if (i < print_len - 1) std::cout << ", "; - } - std::cout << "]" << std::endl; - }; - - while (true) { - // Ensure shared state matches current local state - *active_task_id = current_task_id; - - // 0. Check connection status explicitly - if (!sink.is_writable()) { - request_cancel(current_task_id); - ctx_server.queue_results.remove_waiting_task_id(current_task_id); - return false; - } - - // Receive from the CURRENT task ID using pointer to avoid slicing - std::unordered_set ids = { current_task_id }; - server_task_result_ptr result = nullptr; - while (!result) { - result = ctx_server.queue_results.recv_with_timeout(ids, 1); - if (!result && !sink.is_writable()) { - request_cancel(current_task_id); - ctx_server.queue_results.remove_waiting_task_id(current_task_id); - return false; - } - } - - std::vector items_to_buffer; - - if (!result->is_error()) { - // Use format_partial_response_oaicompat to get the correct chunks - std::vector parts; - if (oaicompat) { - parts = format_partial_response_oaicompat(*result, completion_id); - } else { - parts.push_back(result->data); - } - - json raw_data = result->data; // Access raw data for token ID - - for (const auto& r : parts) { - json item = r; - // Attach raw token ID for banning logic - if (raw_data.contains("token")) item["__raw_token_id"] = raw_data["token"]; - items_to_buffer.push_back(item); - } - } else { - items_to_buffer.push_back(result->to_json()); - } - - // 2. Process items into buffer - for (const auto& item : items_to_buffer) { - token_buffer.push_back(item); - } - - print_debug_buffer(token_buffer); - - // 3. Check for Stop Phrases (Strings & Regex) - std::string buffer_text = ""; - std::vector token_offsets; - - for (const auto& item : token_buffer) { - token_offsets.push_back(buffer_text.length()); - buffer_text += get_content_str(item); - } - - std::string buffer_lower = to_lower_str(buffer_text); - - size_t match_pos = std::string::npos; - std::string detected_phrase = ""; - - // A. Check Strings (Case Insensitive) - for (const auto& phrase : stop_phrases) { - std::string target_lower = to_lower_str(phrase); - size_t pos = buffer_lower.find(target_lower); - if (pos != std::string::npos) { - if (match_pos == std::string::npos || pos < match_pos) { - match_pos = pos; - detected_phrase = phrase; - } - } - } - - // B. Check Regex - for (size_t i = 0; i < stop_regexes.size(); ++i) { - std::smatch match; - // We search the raw buffer_text - if (std::regex_search(buffer_text, match, stop_regexes[i])) { - size_t pos = match.position(0); - if (match_pos == std::string::npos || pos < match_pos) { - match_pos = pos; - detected_phrase = "REGEX:" + regex_patterns[i]; - } - } - } - - if (match_pos != std::string::npos) { - std::cout << "Debug: Stop phrase '" << detected_phrase << "' detected. Initiating ban logic." << std::endl; - - // Find the guilty token - size_t split_index = 0; - bool found_split = false; - for (size_t i = 0; i < token_offsets.size(); ++i) { - size_t token_start = token_offsets[i]; - std::string content = get_content_str(token_buffer[i]); - size_t token_end = token_start + content.length(); - - if (token_end > match_pos) { - split_index = i; - found_split = true; - break; - } - } - - if (found_split) { - // 1. Construct prompt from good tokens (DO NOT FLUSH) - std::string temp_prompt_suffix = ""; - std::deque good_tokens; - - for (size_t i = 0; i < split_index; ++i) { - json& item = token_buffer[i]; - if (item.contains("__raw_token_id")) item.erase("__raw_token_id"); - temp_prompt_suffix += get_content_str(item); - good_tokens.push_back(item); - } - - // 2. Identify Guilty Token & Add to Bans - json& guilty_item = token_buffer[split_index]; - int guilty_token_id = get_token_id(guilty_item); - - if (guilty_token_id == -1) { - std::string content = get_content_str(guilty_item); - auto tokens = ctx_server.tokenize(content, false); - if (!tokens.empty()) guilty_token_id = tokens[0]; - } - - if (guilty_token_id != -1) { - // Check if we are banning a different slot than before - if (ban_slot_index != (int)split_index) { - current_step_bans.clear(); - ban_slot_index = (int)split_index; - } - - current_step_bans.insert(guilty_token_id); - std::cout << "Debug: Banning token ID " << guilty_token_id << " at slot " << split_index << ". Total bans: " << current_step_bans.size() << std::endl; - - // 3. Cancel current task - request_cancel(current_task_id); - ctx_server.queue_results.remove_waiting_task_id(current_task_id); - - // 4. FIX STEP: Generate 1 token with ALL current bans - json fix_data = data; - fix_data["prompt"] = current_prompt_str + temp_prompt_suffix; - fix_data["n_predict"] = 1; - - // Robust logit_bias handling - if (!fix_data.contains("logit_bias")) { - fix_data["logit_bias"] = json::array(); - } - - if (fix_data["logit_bias"].is_array()) { - for (int banned_id : current_step_bans) { - fix_data["logit_bias"].push_back(json::array({banned_id, ban_bias})); - } - } else if (fix_data["logit_bias"].is_object()) { - for (int banned_id : current_step_bans) { - fix_data["logit_bias"][std::to_string(banned_id)] = ban_bias; - } - } - - std::cout << "Debug: Fix Data Logit Bias: " << fix_data["logit_bias"].dump() << std::endl; - - int id_fix = ctx_server.queue_tasks.get_new_id(); - *active_task_id = id_fix; // Update shared state for fix task - ctx_server.queue_results.add_waiting_task_id(id_fix); - - std::vector fix_inputs = tokenize_input_prompts( - llama_get_vocab(ctx_server.ctx), ctx_server.mctx, fix_data["prompt"], true, true - ); - - // Use helper - post_task_with_params(id_fix, fix_data, fix_inputs[0]); - - // Wait for the fix token - std::unordered_set fix_ids = { id_fix }; - server_task_result_ptr fix_result = nullptr; - while (!fix_result) { - fix_result = ctx_server.queue_results.recv_with_timeout(fix_ids, 1); - if (!fix_result && !sink.is_writable()) { - request_cancel(id_fix); - ctx_server.queue_results.remove_waiting_task_id(id_fix); - return false; - } - } - ctx_server.queue_results.remove_waiting_task_id(id_fix); - - // Check for error in fix result - if (fix_result->is_error()) { - std::cout << "Debug: Fix task failed with error." << std::endl; - sse(fix_result->to_json()); - return false; - } - - // Process fix token - json fix_token_json; - json raw_fix = fix_result->data; - - // Use format_partial_response_oaicompat for fix token too - if (oaicompat) { - std::vector parts = format_partial_response_oaicompat(*fix_result, completion_id); - if (!parts.empty()) fix_token_json = parts[0]; - } else { - fix_token_json = fix_result->data; - } - - if (raw_fix.contains("token")) fix_token_json["__raw_token_id"] = raw_fix["token"]; - - std::string fix_content = get_content_str(fix_token_json); - - // 5. RESUME STEP: Continue generation normally - json resume_data = data; - bool stop_after_fix = false; - - if (original_n_predict > 0) { - int pending = good_tokens.size() + 1; - if (total_tokens_streamed + pending >= original_n_predict) { - stop_after_fix = true; - } else { - resume_data["n_predict"] = original_n_predict - (total_tokens_streamed + pending); - } - } - - if (stop_after_fix) { - token_buffer = good_tokens; - token_buffer.push_back(fix_token_json); - - while (!token_buffer.empty()) { - json& item = token_buffer.front(); - if (item.contains("__raw_token_id")) item.erase("__raw_token_id"); - if (!sse(item)) { - request_cancel(*active_task_id); - ctx_server.queue_results.remove_waiting_task_id(*active_task_id); - return false; - } - total_tokens_streamed++; - token_buffer.pop_front(); - } - successful_completion = true; - goto cleanup; - } - - resume_data["prompt"] = current_prompt_str + temp_prompt_suffix + fix_content; - - current_task_id = ctx_server.queue_tasks.get_new_id(); - *active_task_id = current_task_id; // Update shared state for resume task - ctx_server.queue_results.add_waiting_task_id(current_task_id); - - std::vector resume_inputs = tokenize_input_prompts( - llama_get_vocab(ctx_server.ctx), ctx_server.mctx, resume_data["prompt"], true, true - ); - - // Use helper - post_task_with_params(current_task_id, resume_data, resume_inputs[0]); - - // 6. Update Buffer: Good Tokens + Fix Token - token_buffer = good_tokens; - token_buffer.push_back(fix_token_json); - - // REMOVED continue; to allow flush logic to run - } - } - } - - // 4. Standard Flush Logic - bool should_flush_all = result->is_stop() || result->is_error(); - - if (token_buffer.size() >= BUFFER_SIZE || should_flush_all) { - while (!token_buffer.empty()) { - if (!should_flush_all && token_buffer.size() < BUFFER_SIZE) { - break; - } - - json& item_to_send = token_buffer.front(); - if (item_to_send.contains("__raw_token_id")) item_to_send.erase("__raw_token_id"); - - current_prompt_str += get_content_str(item_to_send); - - // SMART BAN CLEARING LOGIC - if (ban_slot_index != -1) { - if (0 == ban_slot_index) { - // We are flushing the slot that had bans. - // This means it's now accepted (or we are forced to flush). - current_step_bans.clear(); - ban_slot_index = -1; - } else { - // We are flushing a preceding token. - // The banned slot shifts left. - ban_slot_index--; - } - } - - if (!sse(item_to_send)) { - request_cancel(current_task_id); - ctx_server.queue_results.remove_waiting_task_id(current_task_id); - return false; - } - - total_tokens_streamed++; - token_buffer.pop_front(); - - if (original_n_predict > 0 && total_tokens_streamed >= original_n_predict) { - request_cancel(current_task_id); - ctx_server.queue_results.remove_waiting_task_id(current_task_id); - successful_completion = true; - goto cleanup; - } - } - } - - if (result->is_error()) { - ctx_server.queue_results.remove_waiting_task_id(current_task_id); - return false; - } - - if (result->is_stop()) { - successful_completion = true; - break; - } - } - } - - cleanup: - bool ok = true; - if (successful_completion && oaicompat != OAICOMPAT_TYPE_ANTHROPIC && oaicompat != OAICOMPAT_TYPE_NONE) { - static const std::string done_message = "data: [DONE]\n\n"; - LOG_VERBOSE("data stream", { {"to_send", done_message} }); - if (!sink.write(done_message.c_str(), done_message.size())) { - ok = false; - } - } - sink.done(); - - // Cleanup the active task ID (which might be different from id_task in slow path) - ctx_server.queue_results.remove_waiting_task_id(*active_task_id); - - return ok; - } catch (const std::exception& e) { - // Catch any exceptions to prevent crashing the server - std::cerr << "Exception in streaming handler: " << e.what() << std::endl; - sse(json{{"error", {{"message", e.what()}, {"type", "server_error"}, {"code", 500}}}}); - sink.done(); - if (active_task_id) { - request_cancel(*active_task_id); - ctx_server.queue_results.remove_waiting_task_id(*active_task_id); - } - return false; - } catch (...) { - std::cerr << "Unknown exception in streaming handler" << std::endl; - sse(json{{"error", {{"message", "Unknown error"}, {"type", "server_error"}, {"code", 500}}}}); - sink.done(); - if (active_task_id) { - request_cancel(*active_task_id); - ctx_server.queue_results.remove_waiting_task_id(*active_task_id); - } - return false; - } - }; - - auto on_complete = [active_task_id, &ctx_server, request_cancel](bool) { - // Cancel the currently active task ID - int id_to_cancel = *active_task_id; - request_cancel(id_to_cancel); - ctx_server.queue_results.remove_waiting_task_id(id_to_cancel); - }; - - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } - }; - - // Execute the complex logic - buffer_and_check_string_ban_and_rewind_logic(); - } - }; - - - - const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - auto data = json::parse(req.body); - std::vector files; // dummy - handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - data, - files, - req.is_connection_closed, - res, - OAICOMPAT_TYPE_NONE); - }; - - const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request& req, httplib::Response& res) { - auto body = json::parse(req.body); - json data = oaicompat_chat_params_parse(body); - std::vector files; // dummy - handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - data, - files, - req.is_connection_closed, - res, - OAICOMPAT_TYPE_COMPLETION); - }; - - const auto handle_models = [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) { - json models = { - {"object", "list"}, - {"data", { - { - {"id", params.model_alias}, - {"object", "model"}, - {"created", std::time(0)}, - {"owned_by", "llamacpp"}, - {"meta", model_meta} - }, - }} - }; - - res.set_content(models.dump(), "application/json; charset=utf-8"); - }; - - - - const auto handle_chat_completions = [&ctx_server, ¶ms, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - auto body = json::parse(req.body); - std::vector files; - json data = oaicompat_chat_params_parse(ctx_server.model, body, ctx_server.oai_parser_opt, files); - handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - data, - files, - req.is_connection_closed, - res, - OAICOMPAT_TYPE_CHAT); - }; - - const auto handle_anthropic_messages = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - std::vector files; - json body = json::parse(req.body); - json body_parsed = anthropic_params_from_json( - ctx_server.model, - body, - ctx_server.oai_parser_opt, - files); - return handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - body_parsed, - files, - req.is_connection_closed, - res, - OAICOMPAT_TYPE_ANTHROPIC); - }; - - const auto handle_anthropic_count_tokens = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - std::vector files; - json body = json::parse(req.body); - - // Parse the Anthropic request (max_tokens is not required for count_tokens) - json body_parsed = anthropic_params_from_json( - ctx_server.model, - body, - ctx_server.oai_parser_opt, - files); - - json prompt = body_parsed.at("prompt"); - llama_tokens tokens = tokenize_mixed(llama_get_vocab(ctx_server.ctx), prompt, true, true); - - res_ok(res, {{"input_tokens", static_cast(tokens.size())}}); - return res; - }; - - // same with handle_chat_completions, but without inference part - const auto handle_apply_template = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) { - auto body = json::parse(req.body); - std::vector files; // dummy, unused - json data = oaicompat_chat_params_parse(ctx_server.model, body,ctx_server.oai_parser_opt, files); - res_ok(res, { { "prompt", std::move(data.at("prompt")) } }); - }; - - const auto handle_infill = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - json data = json::parse(req.body); - const int id_task = ctx_server.queue_tasks.get_new_id(); - server_tokens token; // dummy tokens - ctx_server.queue_results.add_waiting_task_id(id_task); - ctx_server.request_completion(id_task, -1, data, true, false, std::move(token)); - std::vector files; // dummy - handle_completions_impl( - SERVER_TASK_TYPE_INFILL, - data, - files, - req.is_connection_closed, - res, - OAICOMPAT_TYPE_NONE); // infill is not OAI compatible - }; - - const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { - const json body = json::parse(req.body); - - std::vector tokens; - if (body.count("content") != 0) { - const bool add_special = json_value(body, "add_special", false); - tokens = ctx_server.tokenize(body.at("content"), add_special); - } - const json data = format_tokenizer_response(tokens); - return res.set_content(data.dump(), "application/json; charset=utf-8"); - }; - - const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { - const json body = json::parse(req.body); - - std::string content; - if (body.count("tokens") != 0) { - const std::vector tokens = body.at("tokens"); - content = tokens_to_str(ctx_server.ctx, tokens); - } - - const json data = format_detokenized_response(content); - return res.set_content(data.dump(), "application/json; charset=utf-8"); - }; - - const auto handle_embeddings_impl = [&ctx_server](const httplib::Request& req, httplib::Response& res, oaicompat_type oaicompat) { - if (!ctx_server.params_base.embedding) { - res_err(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); - return; - } - - if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { - res_err(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); - return; - } - - const json body = json::parse(req.body); - - // for the shape of input/content, see tokenize_input_prompts() - json prompt; - if (body.count("input") != 0) { - prompt = body.at("input"); - } - else if (body.contains("content")) { - oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible - prompt = body.at("content"); - } - else { - res_err(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); - return; - } - - bool use_base64 = false; - if (body.count("encoding_format") != 0) { - const std::string& format = body.at("encoding_format"); - if (format == "base64") { - use_base64 = true; - } - else if (format != "float") { - res_err(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); - return; - } - } - auto vocab = llama_get_vocab(ctx_server.ctx); - auto tokenized_prompts = tokenize_input_prompts(vocab, ctx_server.mctx, prompt, true, true); - for (const auto& tokens : tokenized_prompts) { - // this check is necessary for models that do not add BOS token to the input - if (tokens.empty()) { - res_err(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); - return; - } - } - - int embd_normalize = 2; // default to Euclidean/L2 norm - if (body.count("embd_normalize") != 0) { - embd_normalize = body.at("embd_normalize"); - if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { - SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx)); - } - } - - // create and queue the task - json responses = json::array(); - server_response_reader rd(ctx_server); - { - std::vector tasks; - for (size_t i = 0; i < tokenized_prompts.size(); i++) { - server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); - - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.tokens = std::move(tokenized_prompts[i]); - - // OAI-compat - task.params.oaicompat = oaicompat; - task.params.embd_normalize = embd_normalize; - task.embedding = true; // probably not needed - tasks.push_back(std::move(task)); - } - rd.post_tasks(std::move(tasks)); - } - - // wait for the results - auto all_results = rd.wait_for_all(req.is_connection_closed); - - // collect results - if (all_results.is_terminated) { - llama_decode_stop(); - return; // connection is closed - } - else if (all_results.error) { - res_err(res, all_results.error->to_json()); - return; - } - else { - for (auto& res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - responses.push_back(res->to_json()); - } - } - - // write JSON response - json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING - ? format_embeddings_response_oaicompat(body, responses, use_base64) - : json(responses); - res_ok(res, root); - - }; - - const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request& req, httplib::Response& res) { - handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE); - }; - - const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request& req, httplib::Response& res) { - handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING); - }; - - - const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) { - json result = json::array(); - for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) { - auto & la = ctx_server.lora_adapters[i]; - result.push_back({ - {"id", i}, - {"path", la.path}, - {"scale", la.scale}, - }); - } - res.set_content(result.dump(), "application/json"); - res.status = 200; // HTTP OK - }; - - - const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) { - const std::vector body = json::parse(req.body); - int max_idx = ctx_server.lora_adapters.size(); - - // clear existing value - for (auto & la : ctx_server.lora_adapters) { - la.scale = 0.0f; - } - - // set value - for (auto entry : body) { - int id = entry.at("id"); - float scale = entry.at("scale"); - if (0 <= id && id < max_idx) { - ctx_server.lora_adapters[id].scale = scale; - } else { - throw std::runtime_error("invalid adapter id"); - } - } - - server_task task; - task.type = SERVER_TASK_TYPE_SET_LORA; - const int id_task = ctx_server.queue_tasks.post(std::move(task)); - ctx_server.queue_results.add_waiting_task_id(id_task); - - server_task_result result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); - - res.set_content(result.data.dump(), "application/json"); - res.status = 200; // HTTP OK - }; - - const auto list_saved_prompts = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) { - json response = json::array(); - - try { - for (const auto& entry : fs::directory_iterator(params.slot_save_path)) { - if (!entry.is_regular_file() || entry.file_size() < 12) { - continue; - } - - std::ifstream file(entry.path(), std::ios::binary); - if (!file) continue; - - uint32_t magic, version, n_token_count; - file.read(reinterpret_cast(&magic), sizeof(magic)); - file.read(reinterpret_cast(&version), sizeof(version)); - file.read(reinterpret_cast(&n_token_count), sizeof(n_token_count)); - - if (magic != LLAMA_STATE_SEQ_MAGIC || - version != LLAMA_STATE_SEQ_VERSION || - entry.file_size() < (12 + (n_token_count * sizeof(llama_token)))) { - continue; - } - - std::vector tokens(n_token_count); - file.read(reinterpret_cast(tokens.data()), tokens.size() * sizeof(llama_token)); - - //C++17 is not modern enough to have a nice and portable way to get the mtime of a file - //so the following seems to be needed - auto ftime = fs::last_write_time(entry.path()); - auto system_time = std::chrono::time_point_cast( - ftime - fs::file_time_type::clock::now() + std::chrono::system_clock::now() - ); - std::time_t c_time = std::chrono::system_clock::to_time_t(system_time); - std::tm tm_struct; - #if defined(_WIN32) - localtime_s(&tm_struct, &c_time); - #else - localtime_r(&c_time, &tm_struct); - #endif - std::ostringstream oss; - oss << std::put_time(&tm_struct, "%Y-%m-%d %H:%M:%S"); - auto str_time = oss.str(); - - - response.push_back({ - {"filename", entry.path().filename().string()}, - {"filesize", entry.file_size()}, - {"mtime", str_time}, - {"token_count", n_token_count}, - {"prompt", tokens_to_str(ctx_server.ctx, tokens)} - }); - } - } catch (const std::exception& e) { - res.status = 500; - response = {{"error", e.what()}}; - } - res.set_content(response.dump(), "application/json; charset=utf-8"); - }; - - const auto list_slot_prompts = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) { - json response = json::array(); - for (server_slot & slot : ctx_server.slots) { - response.push_back({ - {"slot_id", slot.id}, - {"token_count", slot.cache_tokens.size()}, - {"prompt", slot.cache_tokens.detokenize(ctx_server.ctx, true) } - }); - } - res.set_content(response.dump(), "application/json; charset=utf-8"); - }; - - - const auto delete_saved_prompt = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res)-> void { - json response; - namespace fs = std::filesystem; - - try { - const json body = json::parse(req.body); - const std::string filename_str = body.at("filename"); - - // prevent directory traversal attacks - if (filename_str.find("..") != std::string::npos || filename_str.find('/') != std::string::npos || filename_str.find('\\') != std::string::npos) { - res.status = 400; - response = {{"error", "Invalid filename format."}}; - res.set_content(response.dump(), "application/json; charset=utf-8"); - return; - } - - const fs::path file_to_delete = fs::path(params.slot_save_path) / fs::path(filename_str); - - if (!fs::exists(file_to_delete) || !fs::is_regular_file(file_to_delete)) { - res.status = 404; - response = {{"error", "File not found."}}; - res.set_content(response.dump(), "application/json; charset=utf-8"); - return; - } - - if (fs::remove(file_to_delete)) { - response = { - {"status", "deleted"}, - {"filename", filename_str} - }; - } else { - res.status = 500; - response = {{"error", "Failed to delete the file."}}; - } - } catch (const json::parse_error& e) { - res.status = 400; - response = {{"error", "Invalid JSON request body."}}; - } catch (const json::out_of_range& e) { - res.status = 400; - response = {{"error", "Missing 'filename' key in request body."}}; - } catch (const std::exception& e) { - res.status = 500; - response = {{"error", e.what()}}; - } - res.set_content(response.dump(), "application/json; charset=utf-8"); - }; - - const auto rename_saved_prompt = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res)-> void { - json response; - namespace fs = std::filesystem; - - try { - const json body = json::parse(req.body); - const std::string old_filename_str = body.at("old_filename"); - const std::string new_filename_str = body.at("new_filename"); - - if (old_filename_str.find("..") != std::string::npos || old_filename_str.find_first_of("/\\") != std::string::npos || - new_filename_str.find("..") != std::string::npos || new_filename_str.find_first_of("/\\") != std::string::npos) { - res.status = 400; - response = {{"error", "Invalid filename format."}}; - res.set_content(response.dump(), "application/json; charset=utf-8"); - return; - } - - const fs::path old_path = fs::path(params.slot_save_path) / old_filename_str; - const fs::path new_path = fs::path(params.slot_save_path) / new_filename_str; - - if (!fs::exists(old_path) || !fs::is_regular_file(old_path)) { - res.status = 404; - response = {{"error", "Source file not found."}}; - res.set_content(response.dump(), "application/json; charset=utf-8"); - return; - } - - if (fs::exists(new_path)) { - res.status = 409; - response = {{"error", "Destination filename already exists."}}; - res.set_content(response.dump(), "application/json; charset=utf-8"); - return; - } - - std::error_code ec; - fs::rename(old_path, new_path, ec); - - if (ec) { - res.status = 500; - response = {{"error", "Failed to rename file: " + ec.message()}}; - } else { - response = { - {"status", "renamed"}, - {"old_filename", old_filename_str}, - {"new_filename", new_filename_str} - }; - } - - } catch (const json::parse_error& e) { - res.status = 400; - response = {{"error", "Invalid JSON request body."}}; - } catch (const json::out_of_range& e) { - res.status = 400; - response = {{"error", "Missing 'old_filename' or 'new_filename' in request body."}}; - } catch (const std::exception& e) { - res.status = 500; - response = {{"error", e.what()}}; - } - - res.set_content(response.dump(), "application/json; charset=utf-8"); - }; - - auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) { - return [content, len, mime_type](const httplib::Request &, httplib::Response & res) { - res.set_content(reinterpret_cast(content), len, mime_type); - return false; - }; - }; -#ifdef SQLITE3_MODERN_CPP_SUPPORT - const auto handle_version = [¶ms, sqlite_extension_loaded](const httplib::Request&, httplib::Response& res) { - res.set_content( - json{{"version", 4}, - {"features", {{"sql", !params.sql_save_file.empty()}, {"zstd_compression", sqlite_extension_loaded}}}}.dump(), - "application/json" - ); - }; -#else - const auto handle_version = [](const httplib::Request&, httplib::Response& res)-> void { - res.set_content( - json{{"version", 4}, - {"features", {{"sql", false}, {"zstd_compression", false}}}}.dump(), - "application/json" - ); - }; -#endif - -#ifdef SQLITE3_MODERN_CPP_SUPPORT - auto db_handler = [db_handle](auto func) { - return [func, db_handle](const httplib::Request& req, httplib::Response& res) { - res.set_header("Access-Control-Allow-Origin", "*"); - try { - const json body = !req.body.empty() ? json::parse(req.body) : json::object(); - func(*db_handle, body, req, res); - } catch(const std::exception& e) { - res.status = 500; - res.set_content( - json{{"ok", false}, {"message", e.what()}}.dump(), - "application/json" - ); - } - }; - }; -#else - auto db_handler = [db_handle](auto func) { - return [func, db_handle](const httplib::Request& req, httplib::Response& res) { - res.set_header("Access-Control-Allow-Origin", "*"); - res.status = 500; - res.set_content( - json{{"ok", false}, {"message", "Sqlite3 support was not enabled. Recompile with '-DLLAMA_SERVER_SQLITE3=ON'"}}.dump(), - "application/json" - ); - }; - }; -#endif - - const auto normalize_store_name = [](const std::string& storeName) { - if(storeName.empty()) return std::string("sessions"); - - std::string normalized; - normalized.reserve(storeName.size()); - - for(char c : storeName) { - if(std::isalpha(static_cast(c))) { - normalized.push_back(std::tolower(static_cast(c))); - } - } - - return normalized.empty() ? "sessions" : normalized; - }; - - const auto get_key_string = [](const json& j) { - return j.is_string() ? j.get() : j.dump(); - }; - - - const auto handle_load = db_handler([normalize_store_name, get_key_string](auto& db, const json& body, auto&, auto& res) { - std::string data; - const std::string store = normalize_store_name(body["storeName"]); - db.db << "SELECT data FROM " + store + " WHERE key = ?" << get_key_string(body["key"]) >> data; - if(data.empty()) { - res.status = 404; - res.set_content(json{{"ok", false}, {"message", "Key not found"}}.dump(), "application/json"); - } else { - json response{{"ok", true}}; - response["result"] = (store == "names") ? json(data) : json::parse(data); - res.set_content(response.dump(), "application/json"); - } - }); - - const auto handle_save = db_handler([normalize_store_name, get_key_string](auto& db, const json& body, auto&, auto& res) { - const std::string store = normalize_store_name(body["storeName"]); - const std::string data = (store == "names") ? body["data"].get() : body["data"].dump(); - db.db << "INSERT OR REPLACE INTO " + store + " (key, data) VALUES (?, ?)" << get_key_string(body["key"]) << data; - res.set_content(json{{"ok", true}, {"result", "Data saved successfully"}}.dump(), "application/json"); - }); - - const auto handle_rename = db_handler([get_key_string](auto& db, const json& body, auto&, auto& res) { - db.db << "UPDATE names SET data = ? WHERE key = ?" - << body["newName"].get() - << get_key_string(body["key"]); - res.set_content(json{{"ok", true}, {"result", "Session renamed successfully"}}.dump(), "application/json"); - }); - - const auto handle_all = db_handler([normalize_store_name](auto& db, const json& body, auto&, auto& res) { - json result = json::object(); - db.db << "SELECT key, data FROM " + normalize_store_name(body["storeName"]) >> - [&](const std::string& key, const std::string& data) { - result[key] = json::parse(data); - }; - res.set_content(json{{"ok", true}, {"result", result}}.dump(), "application/json"); - }); - - const auto handle_sessions = db_handler([](auto& db, const json& body, auto&, auto& res) { - json result = json::object(); - db.db << "SELECT key, data FROM names" >> [&](const std::string& key, const std::string& data) { - result[key] = data; - }; - res.set_content(json{{"ok", true}, {"result", result}}.dump(), "application/json"); - }); - - const auto handle_delete = db_handler([normalize_store_name, get_key_string](auto& db, const json& body, auto&, auto& res) { - db.db << "DELETE FROM " + normalize_store_name(body["storeName"]) + " WHERE key = ?" - << get_key_string(body["key"]); - res.set_content(json{{"ok", true}, {"result", "Session deleted successfully"}}.dump(), "application/json"); - }); - - const auto handle_vacuum = db_handler([](auto& db, const json& body, auto&, auto& res) { - json result = json::object(); - db.db << "VACUUM"; - res.set_content(json{"ok", true}.dump(), "application/json"); - }); - - const auto handle_zstd_get_configs = db_handler([](auto& db, const json& body, auto&, auto& res) { - json result = json::object(); - db.db << "SELECT id, config FROM _zstd_configs" >> [&](const std::string id, const std::string& config) { - result[id] = config; - }; - res.set_content(json{{"ok", true}, {"configs", result}}.dump(), "application/json"); - }); - - const auto handle_zstd_maintenance = db_handler([](auto& db, const json& body, auto&, auto& res) { - std::string data; - if (body["duration"].is_null()) { - db.db << "select zstd_incremental_maintenance(?, ?)" << nullptr << body["db_load"].get() >> data; - } - else { - db.db << "select zstd_incremental_maintenance(?, ?)" << body["duration"].get() << body["db_load"].get() >> data; - } - json response{{"ok", true}}; - response["result"] = json::parse(data); - res.set_content(response.dump(), "application/json"); - }); - - const auto handle_zstd_enable = db_handler([](auto& db, const json& body, auto&, auto& res) { - db.db << "select zstd_enable_transparent('{\"table\": \"" + body["table"].get() + "\",\"column\": \"" + body["column"].get() + "\", \"compression_level\": " + std::to_string(body["compression_level"].get()) + ", \"dict_chooser\": \"''a''\", \"train_dict_samples_ratio\": " + std::to_string(body["train_dict_samples_ratio"].get()) + "}')"; - res.set_content(json{"ok", true}.dump(), "application/json"); - }); - - const auto handle_zstd_config_update = db_handler([](auto& db, const json& body, auto&, auto& res) { - std::string patch_json = "{\"compression_level\": " + std::to_string(body["compression_level"].get()) + ", \"train_dict_samples_ratio\": " + std::to_string(body["train_dict_samples_ratio"].get()) + "}"; - db.db << "update _zstd_configs set config = json_patch(config, '" + patch_json + "')"; - res.set_content(json{{"ok", true}}.dump(), "application/json"); - }); - - // - // Router - // - if (params.webui == COMMON_WEBUI_NONE) { - LLAMA_LOG_INFO("Web UI is disabled\n"); - } - else { - // register static assets routes - if (!params.public_path.empty()) { - // Set the base directory for serving static files - svr->set_base_dir(params.public_path); - } - - { - // register static assets routes - if (!params.public_path.empty()) { - // Set the base directory for serving static files - bool is_found = svr->set_mount_point("/", params.public_path); - if (!is_found) { - GGML_ABORT("%s: static assets path not found: %s\n", __func__, params.public_path.c_str()); - return 1; - } - } - else { - - // using embedded static index.html - svr->Get("/", [params](const httplib::Request& req, httplib::Response& res) { - if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) { - res.set_content("Error: gzip is not supported by this browser", "text/plain"); - } - else { - res.set_header("Content-Encoding", "gzip"); - // COEP and COOP headers, required by pyodide (python interpreter) - res.set_header("Cross-Origin-Embedder-Policy", "require-corp"); - res.set_header("Cross-Origin-Opener-Policy", "same-origin"); - if (params.webui == COMMON_WEBUI_AUTO) { - res.set_content(reinterpret_cast(index_html_gz), index_html_gz_len, "text/html; charset=utf-8"); - } - else if (params.webui == COMMON_WEBUI_LLAMACPP) { - res.set_content(reinterpret_cast(index_llamacpp_html_gz), index_llamacpp_html_gz_len, "text/html; charset=utf-8"); - } - else { - res.set_content(reinterpret_cast(index_html_gz), index_html_gz_len, "text/html; charset=utf-8"); - } - } - return false; - }); - } - } - } - // register API routes - svr->Get ("/health", handle_health); - svr->Get ("/metrics", handle_metrics); - svr->Get ("/props", handle_props); - svr->Get("/v1/props", handle_props_simple); - svr->Get ("/v1/models", handle_models); - svr->Post("/completion", handle_completions); // legacy - svr->Post("/completions", handle_completions); // legacy - svr->Post("/v1/completions", handle_completions_oai); - svr->Post("/chat/completions", handle_chat_completions); - svr->Post("/v1/chat/completions", handle_chat_completions); - svr->Post("/v1/messages", handle_anthropic_messages); - svr->Post("/v1/messages/count_tokens", handle_anthropic_count_tokens); - svr->Post("/infill", handle_infill); - svr->Post("/embedding", handle_embeddings); // legacy - svr->Post("/embeddings", handle_embeddings); - svr->Post("/v1/embeddings", handle_embeddings_oai); - svr->Post("/tokenize", handle_tokenize); - svr->Post("/detokenize", handle_detokenize); - svr->Post("/apply-template", handle_apply_template); - // LoRA adapters hotswap - svr->Get ("/lora-adapters", handle_lora_adapters_list); - svr->Post("/lora-adapters", handle_lora_adapters_apply); - // Save & load slots - svr->Get ("/slots", handle_slots); - svr->Get ("/slots/list", list_slot_prompts); - if (!params.slot_save_path.empty()) { - // these endpoints rely on slot_save_path existing - svr->Post("/slots/:id_slot", handle_slots_action); - svr->Get ("/list", list_saved_prompts); - svr->Post("/delete_prompt", delete_saved_prompt); - svr->Post("/rename_prompt", rename_saved_prompt); - - } - - svr->Get ("/version", handle_version); - if (!params.sql_save_file.empty()) { - // these endpoints rely on sql_save_file existing - svr->Post("/load", handle_load); - svr->Post("/save", handle_save); - svr->Post("/rename", handle_rename); - svr->Post("/all", handle_all); - svr->Post("/sessions", handle_sessions); - svr->Get ("/sessions", handle_sessions); - svr->Post("/delete", handle_delete); - //VACUUM is there for the extension but does not require the extension - svr->Get ("/vacuum", handle_vacuum); -#ifdef SQLITE3_MODERN_CPP_SUPPORT - if (sqlite_extension_loaded) { - svr->Get ("/zstd_get_configs", handle_zstd_get_configs); - svr->Post("/zstd_incremental_maintenance", handle_zstd_maintenance); - svr->Post("/zstd_enable_transparent", handle_zstd_enable); - svr->Post("/zstd_update_transparent", handle_zstd_config_update); - } -#endif - } - // - // Start the server - // - if (params.n_threads_http < 1) { - // +2 threads for monitoring endpoints - params.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1); - } - log_data["n_threads_http"] = std::to_string(params.n_threads_http); - svr->new_task_queue = [¶ms] { return new httplib::ThreadPool(params.n_threads_http); }; - - LOG_INFO("HTTP server listening", log_data); - - // run the HTTP server in a thread - see comment below - std::thread t([&]() { - if (!svr->listen_after_bind()) { - state.store(SERVER_STATE_ERROR); - return 1; - } - - return 0; - }); - - ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) { - ctx_server.process_single_task(std::move(task)); - }); - ctx_server.queue_tasks.on_finish_multitask(std::bind( - &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1)); - ctx_server.queue_tasks.on_update_slots(std::bind( - &server_context::update_slots, &ctx_server)); - ctx_server.queue_results.on_multitask_update(std::bind( - &server_queue::update_multitask, - &ctx_server.queue_tasks, - std::placeholders::_1, - std::placeholders::_2, - std::placeholders::_3 - )); - - shutdown_handler = [&](int) { - ctx_server.queue_tasks.terminate(); - }; - -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) - struct sigaction sigint_action; - sigint_action.sa_handler = signal_handler; - sigemptyset (&sigint_action.sa_mask); - sigint_action.sa_flags = 0; - sigaction(SIGINT, &sigint_action, NULL); - sigaction(SIGTERM, &sigint_action, NULL); -#elif defined (_WIN32) - auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { - return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; - }; - SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); -#endif - - ctx_server.queue_tasks.start_loop(); - - svr->stop(); - t.join(); - - llama_backend_free(); - - return 0; -} +#pragma warning(disable : 4996) +#include "server-context.h" +#include "server-common.h" +#include "chat.h" + +#include "common.h" +#include "speculative.h" +#include "mtmd.h" +#include "sampling.h" +#include "llama.h" +#include "llama-vocab.h" +#include +#include +#include + +// mime type for sending response +#define MIMETYPE_JSON "application/json; charset=utf-8" + + +#ifndef NDEBUG +// crash the server in debug mode, otherwise send an http 500 error +#define CPPHTTPLIB_NO_EXCEPTIONS 1 +#endif + +#include +#include "index.html.gz.hpp" +#include "index_llamacpp.html.gz.hpp" +#include "loading.html.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef SQLITE3_MODERN_CPP_SUPPORT +#include + +struct DatabaseHandle { + sqlite::database db; + + DatabaseHandle(const std::string& path) : db(path) { + db << "CREATE TABLE IF NOT EXISTS sessions (key TEXT PRIMARY KEY, data TEXT)"; + db << "CREATE TABLE IF NOT EXISTS templates (key TEXT PRIMARY KEY, data TEXT)"; + db << "CREATE TABLE IF NOT EXISTS names (key TEXT PRIMARY KEY, data TEXT)"; + } +}; +#endif + +using json = nlohmann::ordered_json; +namespace fs = std::filesystem; +constexpr int HTTP_POLLING_SECONDS = 1; + +bool server_verbose = false; +bool server_log_json = true; + + +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded + SERVER_STATE_ERROR // An error occurred, load_model failed +}; + + +static inline std::string stop_type_to_str(stop_type type) { + switch (type) { + case STOP_TYPE_EOS: return "eos"; + case STOP_TYPE_WORD: return "word"; + case STOP_TYPE_LIMIT: return "limit"; + default: return "none"; + } +} + + +inline std::string get_model_name(std::string path) +{ + std::string filename = path.substr(path.find_last_of("/\\") + 1); + return filename; +}; + + +static json format_final_response_oaicompat(const json& request, json result, const std::string& completion_id, bool streaming = false) { + bool stopped_word = result.count("stopped_word") != 0; + bool stopped_eos = json_value(result, "stopped_eos", false); + int num_tokens_predicted = json_value(result, "tokens_predicted", 0); + int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); + std::string content = json_value(result, "content", std::string("")); + + std::string finish_reason = "length"; + if (stopped_word || stopped_eos) { + finish_reason = "stop"; + } + + json choices = + streaming ? json::array({ json{{"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()}} }) + : json::array({ json{{"finish_reason", finish_reason}, + {"index", 0}, + {"message", json{{"content", content}, + {"role", "assistant"}}}} }); + + std::time_t t = std::time(0); + + json res = json{ + {"choices", choices}, + {"created", t}, + {"model", + json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", streaming ? "chat.completion.chunk" : "chat.completion"}, + {"usage", json { + {"completion_tokens", num_tokens_predicted}, + {"prompt_tokens", num_prompt_tokens}, + {"total_tokens", num_tokens_predicted + num_prompt_tokens} + }}, + {"id", completion_id} + }; + + if (server_verbose) { + res["__verbose"] = result; + } + + if (result.contains("completion_probabilities")) { + res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array()); + } + + return res; +} + +// return value is vector as there is one case where we might need to generate two responses +static std::vector format_partial_response_oaicompat(server_task_result task_result, const std::string& completion_id) { + json result = task_result.data; + std::cout << result.dump(4) << std::endl; + if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) { + return std::vector({ result }); + } + + bool first = json_value(result, "oaicompat_token_ctr", 0) == 0; + std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); + + bool stopped_word = json_value(result, "stopped_word", false); + bool stopped_eos = json_value(result, "stopped_eos", false); + bool stopped_limit = json_value(result, "stopped_limit", false); + std::string content = json_value(result, "content", std::string("")); + + std::string finish_reason; + if (stopped_word || stopped_eos) { + finish_reason = "stop"; + } + if (stopped_limit) { + finish_reason = "length"; + } + + std::time_t t = std::time(0); + + json choices; + + if (!finish_reason.empty()) { + choices = json::array({ json{{"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()}} }); + } + else { + if (first) { + if (content.empty()) { + choices = json::array({ json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{{"role", "assistant"}}}} }); + } + else { + // We have to send this as two updates to conform to openai behavior + json initial_ret = json{ {"choices", json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"role", "assistant"} + }}}})}, + {"created", t}, + {"id", completion_id}, + {"model", modelname}, + {"object", "chat.completion.chunk"} }; + + json second_ret = json{ + {"choices", json::array({json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"content", content}}} + }})}, + {"created", t}, + {"id", completion_id}, + {"model", modelname}, + {"object", "chat.completion.chunk"} }; + + return std::vector({ initial_ret, second_ret }); + } + } + else { + // Some idiosyncrasy in task processing logic makes several trailing calls + // with empty content, we ignore these at the calee site. + if (content.empty()) { + return std::vector({ json::object() }); + } + + choices = json::array({ json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", + json{ + {"content", content}, + }}, + } }); + } + } + + json ret = json{ + {"choices", choices}, + {"created", t}, + {"id", completion_id}, + {"model", modelname}, + {"object", "chat.completion.chunk"} + }; + + if (task_result.timings.prompt_n != -1) { + ret.push_back({ "timings", task_result.timings.to_json() }); + } + + // + if (!finish_reason.empty()) { + int num_tokens_predicted = json_value(result, "tokens_predicted", 0); + int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); + ret.push_back({ "usage", json { + {"completion_tokens", num_tokens_predicted}, + {"prompt_tokens", num_prompt_tokens}, + {"total_tokens", num_tokens_predicted + num_prompt_tokens} + } }); + } + + return std::vector({ ret }); +} + + +static json format_embeddings_response_oaicompat(const json& request, const json& embeddings, bool use_base64 = false) { + json data = json::array(); + int32_t n_tokens = 0; + int i = 0; + for (const auto& elem : embeddings) { + json embedding_obj; + + if (use_base64) { + const auto& vec = json_value(elem, "embedding", json::array()).get>(); + const char* data_ptr = reinterpret_cast(vec.data()); + size_t data_size = vec.size() * sizeof(float); + embedding_obj = { + {"embedding", base64::encode(data_ptr, data_size)}, + {"index", i++}, + {"object", "embedding"}, + {"encoding_format", "base64"} + }; + } + else { + embedding_obj = { + {"embedding", json_value(elem, "embedding", json::array())}, + {"index", i++}, + {"object", "embedding"} + }; + } + data.push_back(embedding_obj); + n_tokens += json_value(elem, "tokens_evaluated", 0); + } + json res = json{ + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json { + {"prompt_tokens", n_tokens}, + {"total_tokens", n_tokens} + }}, + {"data", data} + }; + + return res; +} + +static void log_server_request(const httplib::Request & req, const httplib::Response & res) { + // skip GH copilot requests when using default port + if (req.path == "/v1/health" || req.path == "/v1/completions") { + return; + } + + LOG_INFO("request", { + {"remote_addr", req.remote_addr}, + {"remote_port", req.remote_port}, + {"status", res.status}, + {"method", req.method}, + {"path", req.path}, + {"params", req.params}, + }); + + LOG_VERBOSE("request", { + {"request", req.body}, + {"response", res.body}, + }); +} + +// generator-like API for server responses, support pooling connection state and aggregating results +struct server_response_reader { + std::unordered_set id_tasks; + server_context& ctx_server; + size_t received_count = 0; + bool cancelled = false; + + server_response_reader(server_context& ctx_server) : ctx_server(ctx_server) {} + ~server_response_reader() { + stop(); + } + + void post_tasks(std::vector&& tasks) { + id_tasks = server_task::get_list_id(tasks); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(std::move(tasks)); + } + + bool has_next() { + return !cancelled && received_count < id_tasks.size(); + } + + // return nullptr if should_stop() is true before receiving a result + // note: if one error is received, it will stop further processing and return error result + server_task_result_ptr next(const std::function& should_stop) { + while (true) { + server_task_result_ptr result = ctx_server.queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); + if (result == nullptr) { + // timeout, check stop condition + if (should_stop()) { + SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n"); + return nullptr; + } + } + else { + if (result->is_error()) { + stop(); // cancel remaining tasks + SRV_DBG("%s", "received error result, stopping further processing\n"); + return result; + } + if (result->is_stop()) { + received_count++; + } + return result; + } + } + + // should not reach here + } + + struct batch_response { + bool is_terminated = false; // if true, indicates that processing was stopped before all results were received + std::vector results; + server_task_result_ptr error; // nullptr if no error + }; + + batch_response wait_for_all(const std::function& should_stop) { + batch_response batch_res; + batch_res.results.resize(id_tasks.size()); + while (has_next()) { + auto res = next(should_stop); + if (res == nullptr) { + batch_res.is_terminated = true; + return batch_res; + } + if (res->error) { + batch_res.error = std::move(res); + return batch_res; + } + const size_t idx = res->get_index(); + GGML_ASSERT(idx < batch_res.results.size() && "index out of range"); + GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received"); + batch_res.results[idx] = std::move(res); + } + return batch_res; + } + + void stop() { + ctx_server.queue_results.remove_waiting_task_ids(id_tasks); + if (has_next() && !cancelled) { + // if tasks is not finished yet, cancel them + cancelled = true; + std::vector cancel_tasks; + cancel_tasks.reserve(id_tasks.size()); + for (const auto& id_task : id_tasks) { + SRV_WRN("cancel task, id_task = %d\n", id_task); + server_task task(SERVER_TASK_TYPE_CANCEL); + task.id_target = id_task; + ctx_server.queue_results.remove_waiting_task_id(id_task); + cancel_tasks.push_back(std::move(task)); + } + // push to beginning of the queue, so it has highest priority + ctx_server.queue_tasks.post(std::move(cancel_tasks), true); + } + else { + SRV_DBG("%s", "all tasks already finished, no need to cancel\n"); + } + } +}; + +auto res_err = [](httplib::Response& res, json error_data) { + json final_response{ {"error", error_data} }; + res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON); + res.status = json_value(error_data, "code", 500); +}; + +auto res_ok = [](httplib::Response& res, const json& data) { + res.set_content(data.dump(), "application/json; charset=utf-8"); + res.status = 200; +}; + +std::function shutdown_handler; +std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; + +inline void signal_handler(int signal) { + if (is_terminating.test_and_set()) { + // in case it hangs, we can force terminate the server by hitting Ctrl+C twice + // this is for better developer experience, we can remove when the server is stable enough + fprintf(stderr, "Received second interrupt, terminating immediately.\n"); + exit(1); + } + + shutdown_handler(signal); +} + +int main(int argc, char ** argv) { +#if SERVER_VERBOSE != 1 + log_disable(); +#endif + // own arguments required by this example + gpt_params params; + + if (!gpt_params_parse(argc, argv, params)) { + gpt_params_print_usage(argc, argv, params); + return 1; + } + + // parse arguments from environment variables + gpt_params_parse_from_env(params); + + // TODO: not great to use extern vars + server_log_json = params.log_json; + server_verbose = params.verbosity > 0; + + + // struct that contains llama context and inference + server_context ctx_server; + + if (!params.system_prompt.empty()) { + ctx_server.system_prompt_set(params.system_prompt); + } + + if (params.model_alias == "unknown") { + params.model_alias = params.model; + } + + llama_backend_init(); + llama_numa_init(params.numa); + + LOG_INFO("build info", { + {"build", LLAMA_BUILD_NUMBER}, + {"commit", LLAMA_COMMIT} + }); + + LOG_INFO("system info", { + {"n_threads", params.n_threads}, + {"n_threads_batch", params.n_threads_batch}, + {"total_threads", std::thread::hardware_concurrency()}, + {"system_info", llama_print_system_info()}, + }); + + std::unique_ptr svr; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (params.ssl_file_key != "" && params.ssl_file_cert != "") { + LOG_INFO("Running with SSL", {{"key", params.ssl_file_key}, {"cert", params.ssl_file_cert}}); + svr.reset( + new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str()) + ); + } else { + LOG_INFO("Running without SSL", {}); + svr.reset(new httplib::Server()); + } +#else + svr.reset(new httplib::Server()); +#endif + + std::atomic state{SERVER_STATE_LOADING_MODEL}; + + svr->set_default_headers({{"Server", "ik_llama.cpp"}}); + + svr->set_logger(log_server_request); + + + + svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { + std::string message; + try { + std::rethrow_exception(std::move(ep)); + } catch (std::exception & e) { + message = e.what(); + } catch (...) { + message = "Unknown Exception"; + } + + json formatted_error = format_error_response(message, ERROR_TYPE_SERVER); + LOG_VERBOSE("Got exception", formatted_error); + res_err(res, formatted_error); + }); + + svr->set_error_handler([](const httplib::Request &, httplib::Response & res) { + if (res.status == 404) { + res_err(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND)); + } + // for other error codes, we skip processing here because it's already done by res_err() + }); + + // set timeouts and change hostname and port + svr->set_read_timeout (params.timeout_read); + svr->set_write_timeout(params.timeout_write); + + if (!svr->bind_to_port(params.hostname, params.port)) { + fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", params.hostname.c_str(), params.port); + return 1; + } + + std::unordered_map log_data; + + log_data["hostname"] = params.hostname; + log_data["port"] = std::to_string(params.port); + + if (params.api_keys.size() == 1) { + auto key = params.api_keys[0]; + log_data["api_key"] = "api_key: ****" + key.substr(std::max((int)(key.length() - 4), 0)); + } else if (params.api_keys.size() > 1) { + log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded"; + } + + // Necessary similarity of prompt for slot selection + ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; + ctx_server.cache_ram_n_min = params.cache_ram_n_min; + ctx_server.cache_ram_similarity = params.cache_ram_similarity; +#ifdef SQLITE3_MODERN_CPP_SUPPORT + auto db_handle = std::make_shared(params.sql_save_file); + bool sqlite_extension_loaded = false; + if (!params.sqlite_zstd_ext_file.empty()) { + auto* conn = db_handle->db.connection().get(); + sqlite3_enable_load_extension(conn, 1); + char* errmsg = nullptr; + const int rc = sqlite3_load_extension( + conn, + params.sqlite_zstd_ext_file.c_str(), + nullptr, + &errmsg + ); + if(rc != SQLITE_OK) { + const std::string err = errmsg ? errmsg : "Unknown extension error"; + sqlite3_free(errmsg); + LOG_WARNING("Failed to load extension", {{"err", err}}); + } + else { + sqlite_extension_loaded = true; + } + sqlite3_enable_load_extension(conn, 0); + } +#else + auto db_handle = false; +#endif + // load the model + if (!ctx_server.load_model(params)) { + state.store(SERVER_STATE_ERROR); + return 1; + } else { + ctx_server.init(); + state.store(SERVER_STATE_READY); + } + + LOG_INFO("model loaded", {}); + + const auto model_meta = ctx_server.model_meta(); + + // print sample chat example to make it clear which template is used + + LOG_INFO("chat template", { + {"chat_template", common_chat_templates_source(ctx_server.chat_templates.get())}, + }); + + LOG_INFO("chat template", { + {"chat_example", common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja, {}).c_str() + }, + {"built_in", params.chat_template.empty()}, + }); + // + // Middlewares + // + + auto middleware_validate_api_key = [¶ms](const httplib::Request & req, httplib::Response & res) { + static const std::unordered_set public_endpoints = { + "/health", + "/v1/health", + "/models", + "/v1/models", + "/api/tags" + }; + + // If API key is not set, skip validation + if (params.api_keys.empty()) { + return true; + } + + // If path is public or is static file, skip validation + if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") { + return true; + } + + // Check for API key in the header + auto auth_header = req.get_header_value("Authorization"); + + std::string prefix = "Bearer "; + if (auth_header.substr(0, prefix.size()) == prefix) { + std::string received_api_key = auth_header.substr(prefix.size()); + if (std::find(params.api_keys.begin(), params.api_keys.end(), received_api_key) != params.api_keys.end()) { + return true; // API key is valid + } + } + + auth_header = req.get_header_value("X-Api-Key"); + + if (std::find(params.api_keys.begin(), params.api_keys.end(), auth_header) != params.api_keys.end()) { + return true; // API key is valid + } + + // API key is invalid or not provided + res.status = 401; + res.set_content( + (json { + {"error", { + {"message", "Invalid API Key"}, + {"type", "authentication_error"}, + {"code", 401} + }} + }).dump(-1, ' ', false, json::error_handler_t::replace), + "application/json; charset=utf-8" + ); + LOG_WARNING("Unauthorized: Invalid API Key\n", {}); + return false; + }; + + auto middleware_server_state = [&state](const httplib::Request& req, httplib::Response& res) { + server_state current_state = state.load(); + if (current_state == SERVER_STATE_LOADING_MODEL) { + auto tmp = string_split(req.path, '.'); + if (req.path == "/" || tmp.back() == "html") { + res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); + res.status = 503; + } + else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") { + // allow the models endpoint to be accessed during loading + return true; + } + else { + res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); + } + return false; + } + return true; + }; + + // register server middlewares + svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request& req, httplib::Response& res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + // If this is OPTIONS request, skip validation because browsers don't include Authorization header + if (req.method == "OPTIONS") { + res.set_header("Access-Control-Allow-Credentials", "true"); + res.set_header("Access-Control-Allow-Methods", "GET, POST"); + res.set_header("Access-Control-Allow-Headers", "*"); + res.set_content("", "text/html"); // blank response, no data + return httplib::Server::HandlerResponse::Handled; // skip further processing + } + if (!middleware_server_state(req, res)) { + return httplib::Server::HandlerResponse::Handled; + } + if (!middleware_validate_api_key(req, res)) { + return httplib::Server::HandlerResponse::Handled; + } + return httplib::Server::HandlerResponse::Unhandled; + }); + + // + // Route handlers (or controllers) + // + + const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) { + server_state current_state = state.load(); + switch (current_state) { + case SERVER_STATE_READY: + { + // request slots data using task queue + server_task task; + task.id = ctx_server.queue_tasks.get_new_id(); + task.type = SERVER_TASK_TYPE_METRICS; + task.id_target = -1; + + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(std::move(task)); + + // get the result + server_task_result result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); + + const int n_idle_slots = result.data.at("idle"); + const int n_processing_slots = result.data.at("processing"); + + json health = { + {"status", "ok"}, + {"slots_idle", n_idle_slots}, + {"slots_processing", n_processing_slots} + }; + + res.status = 200; // HTTP OK + if (params.endpoint_slots && req.has_param("include_slots")) { + health["slots"] = result.data.at("slots"); + } + + if (n_idle_slots == 0) { + health["status"] = "no slot available"; + if (req.has_param("fail_on_no_slot")) { + res.status = 503; // HTTP Service Unavailable + } + } + + res.set_content(health.dump(), "application/json"); + break; + } + case SERVER_STATE_LOADING_MODEL: + { + res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); + } break; + case SERVER_STATE_ERROR: + { + res_err(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER)); + } break; + } + }; + + const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) { + if (!params.endpoint_slots) { + res_err(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + + // request slots data using task queue + server_task task; + task.id = ctx_server.queue_tasks.get_new_id(); + task.id_multi = -1; + task.id_target = -1; + task.type = SERVER_TASK_TYPE_METRICS; + + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(std::move(task)); + + // get the result + server_task_result result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); + + res.set_content(result.data.at("slots").dump(), "application/json"); + res.status = 200; // HTTP OK + }; + + const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { + if (!params.endpoint_metrics) { + res_err(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + + // request slots data using task queue + server_task task; + task.id = ctx_server.queue_tasks.get_new_id(); + task.id_multi = -1; + task.id_target = -1; + task.type = SERVER_TASK_TYPE_METRICS; + task.data.push_back({{"reset_bucket", true}}); + + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(std::move(task)); + + // get the result + server_task_result result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); + + json data = result.data; + + const uint64_t n_prompt_tokens_processed = data.at("n_prompt_tokens_processed"); + const uint64_t t_prompt_processing = data.at("t_prompt_processing"); + + const uint64_t n_tokens_predicted = data.at("n_tokens_predicted"); + const uint64_t t_tokens_generation = data.at("t_tokens_generation"); + + const int32_t kv_cache_used_cells = data.at("kv_cache_used_cells"); + + // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names + json all_metrics_def = json { + {"counter", {{ + {"name", "prompt_tokens_total"}, + {"help", "Number of prompt tokens processed."}, + {"value", (uint64_t) data.at("n_prompt_tokens_processed_total")} + }, { + {"name", "prompt_seconds_total"}, + {"help", "Prompt process time"}, + {"value", (uint64_t) data.at("t_prompt_processing_total") / 1.e3} + }, { + {"name", "tokens_predicted_total"}, + {"help", "Number of generation tokens processed."}, + {"value", (uint64_t) data.at("n_tokens_predicted_total")} + }, { + {"name", "tokens_predicted_seconds_total"}, + {"help", "Predict process time"}, + {"value", (uint64_t) data.at("t_tokens_generation_total") / 1.e3} + }}}, + {"gauge", {{ + {"name", "prompt_tokens_seconds"}, + {"help", "Average prompt throughput in tokens/s."}, + {"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.} + },{ + {"name", "predicted_tokens_seconds"}, + {"help", "Average generation throughput in tokens/s."}, + {"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.} + },{ + {"name", "kv_cache_usage_ratio"}, + {"help", "KV-cache usage. 1 means 100 percent usage."}, + {"value", 1. * kv_cache_used_cells / params.n_ctx} + },{ + {"name", "kv_cache_tokens"}, + {"help", "KV-cache tokens."}, + {"value", (uint64_t) data.at("kv_cache_tokens_count")} + },{ + {"name", "requests_processing"}, + {"help", "Number of request processing."}, + {"value", (uint64_t) data.at("processing")} + },{ + {"name", "requests_deferred"}, + {"help", "Number of request deferred."}, + {"value", (uint64_t) data.at("deferred")} + }}} + }; + + std::stringstream prometheus; + + for (const auto & el : all_metrics_def.items()) { + const auto & type = el.key(); + const auto & metrics_def = el.value(); + + for (const auto & metric_def : metrics_def) { + const std::string name = metric_def.at("name"); + const std::string help = metric_def.at("help"); + + auto value = json_value(metric_def, "value", 0.); + prometheus << "# HELP llamacpp:" << name << " " << help << "\n" + << "# TYPE llamacpp:" << name << " " << type << "\n" + << "llamacpp:" << name << " " << value << "\n"; + } + } + + const int64_t t_start = data.at("t_start"); + res.set_header("Process-Start-Time-Unix", std::to_string(t_start)); + + res.set_content(prometheus.str(), "text/plain; version=0.0.4"); + res.status = 200; // HTTP OK + }; + + const auto handle_slots_save = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { + json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return; + } + std::string filepath = params.slot_save_path + filename; + + server_task task; + task.type = SERVER_TASK_TYPE_SLOT_SAVE; + task.data = { + { "id_slot", id_slot }, + { "filename", filename }, + { "filepath", filepath } + }; + + const int id_task = ctx_server.queue_tasks.post(std::move(task)); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + if (result.error) { + res_err(res, result.data); + } else { + res.set_content(result.data.dump(), "application/json"); + } + }; + + const auto handle_slots_restore = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { + json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return; + } + std::string filepath = params.slot_save_path + filename; + + server_task task; + task.type = SERVER_TASK_TYPE_SLOT_RESTORE; + task.data = { + { "id_slot", id_slot }, + { "filename", filename }, + { "filepath", filepath } + }; + + const int id_task = ctx_server.queue_tasks.post(std::move(task)); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + if (result.error) { + res_err(res, result.data); + } else { + res.set_content(result.data.dump(), "application/json"); + } + }; + + const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { + server_task task; + task.type = SERVER_TASK_TYPE_SLOT_ERASE; + task.data = { + { "id_slot", id_slot }, + }; + + const int id_task = ctx_server.queue_tasks.post(std::move(task)); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + if (result.error) { + res_err(res, result.data); + } else { + res.set_content(result.data.dump(), "application/json"); + } + }; + + const auto handle_slots_action = [&handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { + std::string id_slot_str = req.path_params.at("id_slot"); + int id_slot; + + try { + id_slot = std::stoi(id_slot_str); + } catch (const std::exception &) { + res_err(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + std::string action = req.get_param_value("action"); + + if (action == "save") { + handle_slots_save(req, res, id_slot); + } else if (action == "restore") { + handle_slots_restore(req, res, id_slot); + } else if (action == "erase") { + handle_slots_erase(req, res, id_slot); + } else { + res_err(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); + } + }; + + const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + std::string template_key = "tokenizer.chat_template", curr_tmpl; + int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0); + if (tlen > 0) { + std::vector curr_tmpl_buf(tlen + 1, 0); + if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) { + curr_tmpl = std::string(curr_tmpl_buf.data(), tlen); + } + } + json data = { + { "system_prompt", ctx_server.system_prompt.c_str() }, + { "model_alias", ctx_server.params_base.model_alias }, + { "model_path", ctx_server.params_base.model}, + { "default_generation_settings", ctx_server.default_generation_settings_for_props }, + { "total_slots", ctx_server.params_base.n_parallel }, + { "model_name", get_model_name(ctx_server.params_base.model)}, + { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, + { "bos_token", common_token_to_piece(ctx_server.ctx, llama_token_bos(ctx_server.model), /* special= */ true)}, + { "eos_token", common_token_to_piece(ctx_server.ctx, llama_token_eos(ctx_server.model), /* special= */ true)}, + { "model_path", ctx_server.params_base.model }, + { "modalities", json { + {"vision", ctx_server.oai_parser_opt.allow_image}, + {"audio", ctx_server.oai_parser_opt.allow_audio}, + } }, + { "n_ctx", ctx_server.n_ctx } + + }; + + if (ctx_server.params_base.use_jinja) { + if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) { + data["chat_template_tool_use"] = tool_use_src; + } + } + res.set_content(data.dump(), "application/json; charset=utf-8"); + }; + + const auto handle_props_simple = [&ctx_server](const httplib::Request& req, httplib::Response& res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + int n_past = 0; + int slot_id = 0; + for (server_slot& slot : ctx_server.slots) { + if (slot.n_past > n_past) { + n_past = slot.n_past; + slot_id = slot.id; + } + } + json data = { + { "model_name", get_model_name(ctx_server.params_base.model)}, + { "model_path", ctx_server.params_base.model }, + { "modalities", json { + {"vision", ctx_server.oai_parser_opt.allow_image}, + {"audio", ctx_server.oai_parser_opt.allow_audio}, + } }, + { "n_ctx", ctx_server.n_ctx } + }; + res.set_content(data.dump(), "application/json; charset=utf-8"); + }; + + + + +// handle completion-like requests (completion, chat, infill) + // we can optionally provide a custom format for partial results and final results + const auto handle_completions_impl = [&ctx_server, ¶ms]( + server_task_type type, + json& data, + const std::vector& files, + const std::function& is_connection_closed, + httplib::Response& res, + oaicompat_type oaicompat) -> void { + GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); + + // ---------------------------------------------------------------- + // 1. Regex Validation + // ---------------------------------------------------------------- + auto validate_regex_list = [&](const std::string& field_name) -> std::string { + if (data.contains(field_name) && data[field_name].is_array()) { + for (const auto& val : data[field_name]) { + if (val.is_string()) { + std::string s = val.get(); + if (!s.empty()) { + try { + std::regex re(s); + } catch (const std::regex_error& e) { + return s; + } + } + } + } + } + return ""; + }; + + std::string invalid_re = validate_regex_list("banned_regex"); + if (invalid_re.empty()) invalid_re = validate_regex_list("banned_regex_case_insensitive"); + + if (!invalid_re.empty()) { + res_err(res, format_error_response("Invalid regex: " + invalid_re, ERROR_TYPE_INVALID_REQUEST)); + return; + } + + const auto completion_id = gen_chatcmplid(); + + // Process prompt / inputs + std::vector inputs; + try { + const auto& prompt = data.at("prompt"); + if (oaicompat && ctx_server.mctx != nullptr) { + inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); + } + else { + inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true); + } + } + catch (const std::exception& e) { + res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return; + } + + // ---------------------------------------------------------------- + // Check if we need the complex "Banned String" logic + // Only enable if the lists are present AND contain actual strings. + // ---------------------------------------------------------------- + auto list_has_content = [&](const std::string& key) { + if (data.contains(key) && data[key].is_array()) { + for (const auto& item : data[key]) { + if (item.is_string() && !item.get().empty()) { + return true; + } + } + } + return false; + }; + + bool has_banned_content = list_has_content("banned_strings") || + list_has_content("banned_regex") || + list_has_content("banned_regex_case_insensitive"); + + if (!has_banned_content) { + // ---------------------------------------------------------------- + // PATH A: Standard Logic (server_response_reader) + // ---------------------------------------------------------------- + + // need to store the reader as a pointer, so that it won't be destroyed when the handle returns + // use shared_ptr as it's shared between the chunked_content_provider() and on_complete() + const auto rd = std::make_shared(ctx_server); + + try { + std::vector tasks; + tasks.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + server_task task = server_task(type); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + + task.tokens = std::move(inputs[i]); + task.data = data; + task.id_slot = json_value(data, "id_slot", -1); + + // OAI-compat + task.params.oaicompat = oaicompat; + task.params.oaicompat_cmpl_id = completion_id; + task.params.oaicompat_model = get_model_name(ctx_server.params_base.model); + tasks.push_back(std::move(task)); + } + + rd->post_tasks(std::move(tasks)); + } + catch (const std::exception& e) { + res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return; + } + bool stream = json_value(data, "stream", false); + if (!stream) { + // non-stream, wait for the results + auto all_results = rd->wait_for_all(is_connection_closed); + if (all_results.is_terminated) { + llama_decode_stop(); // send a signal to stop decode process + return; // connection is closed + } + else if (all_results.error) { + res_err(res, all_results.error->to_json()); + return; + } + else { + json arr = json::array(); + for (auto& res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + if (oaicompat) { + arr.push_back(format_final_response_oaicompat(data, res->data, completion_id, false)); + } else { + arr.push_back(res->to_json()); + } + } + // if single request, return single object instead of array + res_ok(res, arr.size() == 1 ? arr[0] : arr); + } + } + else { + // in streaming mode, the first error must be treated as non-stream response + // this is to match the OAI API behavior + // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 + server_task_result_ptr first_result = rd->next(is_connection_closed); + if (first_result == nullptr) { + llama_decode_stop(); // send a signal to stop decode process + return; // connection is closed + } + else if (first_result->is_error()) { + res_err(res, first_result->to_json()); + return; + } + else { + GGML_ASSERT( + dynamic_cast(first_result.get()) != nullptr + || dynamic_cast(first_result.get()) != nullptr + ); + } + + // Prepare first result JSON (handling OAI format if needed) + std::vector first_result_parts; + if (oaicompat) { + first_result_parts = format_partial_response_oaicompat(*first_result, completion_id); + } else { + first_result_parts.push_back(first_result->to_json()); + } + + const auto chunked_content_provider = [first_result_parts, rd, oaicompat, completion_id](size_t, httplib::DataSink& sink) mutable -> bool { + const auto sse = [oaicompat, &sink](const json& res) { + if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) { + return server_sent_anthropic_event(sink, res); + } + else { + return server_sent_event(sink, res); + } + }; + + // flush the first result parts + for (auto& part : first_result_parts) { + if (!part.empty()) { + if (!sse(part)) { + sink.done(); + return false; // sending failed, go to on_complete() + } + part.clear(); // mark as sent + } + } + + // receive subsequent results + auto result = rd->next([&sink] { return !sink.is_writable(); }); + if (result == nullptr) { + sink.done(); + return false; // connection is closed, go to on_complete() + } + + // send the results + bool ok = false; + if (result->is_error()) { + ok = sse(json{ { "error", result->to_json() } }); + sink.done(); + return false; // go to on_complete() + } + else { + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + + if (oaicompat) { + std::vector parts = format_partial_response_oaicompat(*result, completion_id); + for (const auto& part : parts) { + ok = sse(part); + if (!ok) break; + } + } else { + ok = sse(result->to_json()); + } + } + + if (!ok) { + sink.done(); + return false; // sending failed, go to on_complete() + } + + // check if there is more data + if (!rd->has_next()) { + if (oaicompat != OAICOMPAT_TYPE_ANTHROPIC && oaicompat != OAICOMPAT_TYPE_NONE) { + static const std::string ev_done = "data: [DONE]\n\n"; + sink.write(ev_done.data(), ev_done.size()); + } + sink.done(); + return false; // no more data, go to on_complete() + } + + // has next data, continue + return true; + }; + + auto on_complete = [rd](bool) { + rd->stop(); + }; + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + } + + } else { + // ---------------------------------------------------------------- + // PATH B: Banned Content Logic (Slow Path with Buffering & Rewind) + // ---------------------------------------------------------------- + auto buffer_and_check_string_ban_and_rewind_logic = [&]() { + // Helper to mimic request_cancel using the task queue directly + auto request_cancel = [&ctx_server](int id_target) { + server_task task(SERVER_TASK_TYPE_CANCEL); + task.id_target = id_target; + std::vector tasks; + tasks.push_back(std::move(task)); + ctx_server.queue_tasks.post(std::move(tasks), true); + }; + + // Helper to post a completion task with correct OAI params + auto post_task_with_params = [&ctx_server, oaicompat, completion_id](int id_task, json& task_data, server_tokens& tokens) { + server_task task(SERVER_TASK_TYPE_COMPLETION); + task.id = id_task; + task.index = 0; + task.tokens = std::move(tokens); + task.data = task_data; + task.id_slot = json_value(task_data, "id_slot", -1); + + // Critical: Set OAI params so worker generates correct output format + task.params.oaicompat = oaicompat; + task.params.oaicompat_cmpl_id = completion_id; + task.params.oaicompat_model = get_model_name(ctx_server.params_base.model); + + std::vector tasks; + tasks.push_back(std::move(task)); + ctx_server.queue_tasks.post(std::move(tasks)); + }; + + const int id_task = ctx_server.queue_tasks.get_new_id(); + ctx_server.queue_results.add_waiting_task_id(id_task); + + // Use helper instead of request_completion + post_task_with_params(id_task, data, inputs[0]); + + bool stream = json_value(data, "stream", false); + + if (!stream) { + // Non-streaming: wait for result (using pointer to avoid slicing) + std::unordered_set ids = { id_task }; + server_task_result_ptr result = nullptr; + + // Simple blocking wait + while (!result) { + result = ctx_server.queue_results.recv_with_timeout(ids, 1); + if (!result && is_connection_closed()) { + request_cancel(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + return; + } + } + + if (!result->is_error()) { + json result_json; + if (oaicompat) { + result_json = format_final_response_oaicompat(data, result->data, completion_id, false); + } else { + result_json = result->to_json(); + } + res.set_content(result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); + } + else { + res_err(res, result->to_json()); + } + ctx_server.queue_results.remove_waiting_task_id(id_task); + } + else { + // Shared state to track the currently running task ID across retries. + auto active_task_id = std::make_shared(id_task); + + // Capture 'data' by value to use as a template for retries + const auto chunked_content_provider = [id_task, active_task_id, &ctx_server, completion_id, oaicompat, send_done = params.send_done, data, request_cancel, post_task_with_params](size_t, httplib::DataSink& sink) mutable { + // Define sse here so it's visible to both try and catch blocks + const auto sse = [oaicompat, &sink](const json &res) { + if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) { + return server_sent_anthropic_event(sink, res); + } else { + return server_sent_event(sink, res); + } + }; + + try { + bool successful_completion = false; + + // 1. Parse Configuration from Request + + // Banned Strings + std::vector stop_phrases; + if (data.contains("banned_strings") && data["banned_strings"].is_array()) { + for (const auto& val : data["banned_strings"]) { + if (val.is_string()) { + std::string s = val.get(); + if (!s.empty()) stop_phrases.push_back(s); + } + } + } + + // Sort banned strings by length (descending) + std::sort(stop_phrases.begin(), stop_phrases.end(), [](const std::string& a, const std::string& b) { + return a.length() > b.length(); + }); + + // Banned Regex (Case Sensitive & Insensitive) + std::vector regex_patterns; // For buffer size calculation + std::vector stop_regexes; // Compiled regexes + + auto add_regex_list = [&](const std::string& field_name, bool case_insensitive) { + if (data.contains(field_name) && data[field_name].is_array()) { + for (const auto& val : data[field_name]) { + if (val.is_string()) { + std::string s = val.get(); + if (!s.empty()) { + auto flags = std::regex_constants::ECMAScript; + if (case_insensitive) flags |= std::regex_constants::icase; + stop_regexes.emplace_back(s, flags); + regex_patterns.push_back(s); + } + } + } + } + }; + + // We assume validation passed in handle_completions_impl, so no try-catch needed here + add_regex_list("banned_regex", false); + add_regex_list("banned_regex_case_insensitive", true); + + // Logit Bias Penalty (Default: -10000.0) + float ban_bias = -10000.0f; + if (data.contains("banned_bias") && data["banned_bias"].is_number()) { + ban_bias = data["banned_bias"].get(); + } + + // Manual Buffer Size + size_t manual_buffer_size = 0; + if (data.contains("banbuffer_size") && data["banbuffer_size"].is_number_unsigned()) { + manual_buffer_size = data["banbuffer_size"].get(); + } + + // Token Limit Tracking + int original_n_predict = -1; + if (data.contains("n_predict") && data["n_predict"].is_number_integer()) { + original_n_predict = data["n_predict"].get(); + } + int total_tokens_streamed = 0; + + // ============================================================ + // FAST PATH: No banned strings AND No regex -> No buffering + // ============================================================ + if (stop_phrases.empty() && stop_regexes.empty()) { + while (true) { + std::unordered_set ids = { *active_task_id }; + server_task_result_ptr result = nullptr; + while (!result) { + result = ctx_server.queue_results.recv_with_timeout(ids, 1); + if (!result && !sink.is_writable()) { + request_cancel(*active_task_id); + ctx_server.queue_results.remove_waiting_task_id(*active_task_id); + return false; + } + } + + if (!result->is_error()) { + // Use format_partial_response_oaicompat to get the correct chunks + std::vector parts; + if (oaicompat) { + parts = format_partial_response_oaicompat(*result, completion_id); + } else { + parts.push_back(result->data); + } + + for (const auto& item : parts) { + if (!sse(item)) { + request_cancel(*active_task_id); + ctx_server.queue_results.remove_waiting_task_id(*active_task_id); + return false; + } + } + + if (result->is_stop()) { + successful_completion = true; + break; + } + } else { + sse(result->to_json()); + ctx_server.queue_results.remove_waiting_task_id(*active_task_id); + return false; + } + } + } + // ============================================================ + // SLOW PATH: Buffering and Banning Logic + // ============================================================ + else { + // Calculate Buffer Size + size_t BUFFER_SIZE; + if (manual_buffer_size > 0) { + BUFFER_SIZE = manual_buffer_size; + } else { + size_t max_len = 0; + // Check strings + if (!stop_phrases.empty()) { + max_len = stop_phrases[0].length(); // First is longest due to sort + } + // Check regex patterns + for (const auto& pat : regex_patterns) { + if (pat.length() > max_len) max_len = pat.length(); + } + + // Default: Longest string/regex + 1 + BUFFER_SIZE = std::max((size_t)1, max_len + 1); + } + + // Initialize Buffer & State + std::deque token_buffer; + + int current_task_id = id_task; + + // Track bans specifically for the current "next token" to be generated. + std::set current_step_bans; + int ban_slot_index = -1; + + // Track the text that has been confirmed/sent to the client. + std::string current_prompt_str = ""; + if (data.contains("prompt") && data["prompt"].is_string()) { + current_prompt_str = data["prompt"].get(); + } + + // Helper to extract text content + auto get_content_str = [](const json& j) -> std::string { + if (j.contains("choices") && j["choices"].is_array() && !j["choices"].empty()) { + const auto& choice = j["choices"][0]; + if (choice.contains("delta") && choice["delta"].contains("content")) { + auto val = choice["delta"]["content"]; + if (val.is_string()) return val.get(); + } + } + if (j.contains("content")) { + auto val = j["content"]; + if (val.is_string()) return val.get(); + } + return ""; + }; + + // Helper to extract Token ID + auto get_token_id = [](const json& j) -> int { + if (j.contains("__raw_token_id")) return j["__raw_token_id"].get(); + if (j.contains("token")) return j["token"].get(); + if (j.contains("id")) return j["id"].get(); + return -1; + }; + + // Helper for case-insensitive search + auto to_lower_str = [](std::string s) { + std::transform(s.begin(), s.end(), s.begin(), + [](unsigned char c){ return std::tolower(c); }); + return s; + }; + + // Helper to print buffer + auto print_debug_buffer = [&](const std::deque& buf) { + std::cout << "Debug TokenBuffer (Size " << BUFFER_SIZE << "): ["; + size_t print_len = std::max(buf.size(), BUFFER_SIZE); + for (size_t i = 0; i < print_len; ++i) { + if (i < buf.size()) { + std::string content = get_content_str(buf[i]); + std::string escaped; + for (char c : content) { + if (c == '\n') escaped += "\\n"; + else if (c == '"') escaped += "\\\""; + else escaped += c; + } + std::cout << "\"" << escaped << "\""; + } else { + std::cout << "\"\""; + } + if (i < print_len - 1) std::cout << ", "; + } + std::cout << "]" << std::endl; + }; + + while (true) { + // Ensure shared state matches current local state + *active_task_id = current_task_id; + + // 0. Check connection status explicitly + if (!sink.is_writable()) { + request_cancel(current_task_id); + ctx_server.queue_results.remove_waiting_task_id(current_task_id); + return false; + } + + // Receive from the CURRENT task ID using pointer to avoid slicing + std::unordered_set ids = { current_task_id }; + server_task_result_ptr result = nullptr; + while (!result) { + result = ctx_server.queue_results.recv_with_timeout(ids, 1); + if (!result && !sink.is_writable()) { + request_cancel(current_task_id); + ctx_server.queue_results.remove_waiting_task_id(current_task_id); + return false; + } + } + + std::vector items_to_buffer; + + if (!result->is_error()) { + // Use format_partial_response_oaicompat to get the correct chunks + std::vector parts; + if (oaicompat) { + parts = format_partial_response_oaicompat(*result, completion_id); + } else { + parts.push_back(result->data); + } + + json raw_data = result->data; // Access raw data for token ID + + for (const auto& r : parts) { + json item = r; + // Attach raw token ID for banning logic + if (raw_data.contains("token")) item["__raw_token_id"] = raw_data["token"]; + items_to_buffer.push_back(item); + } + } else { + items_to_buffer.push_back(result->to_json()); + } + + // 2. Process items into buffer + for (const auto& item : items_to_buffer) { + token_buffer.push_back(item); + } + + print_debug_buffer(token_buffer); + + // 3. Check for Stop Phrases (Strings & Regex) + std::string buffer_text = ""; + std::vector token_offsets; + + for (const auto& item : token_buffer) { + token_offsets.push_back(buffer_text.length()); + buffer_text += get_content_str(item); + } + + std::string buffer_lower = to_lower_str(buffer_text); + + size_t match_pos = std::string::npos; + std::string detected_phrase = ""; + + // A. Check Strings (Case Insensitive) + for (const auto& phrase : stop_phrases) { + std::string target_lower = to_lower_str(phrase); + size_t pos = buffer_lower.find(target_lower); + if (pos != std::string::npos) { + if (match_pos == std::string::npos || pos < match_pos) { + match_pos = pos; + detected_phrase = phrase; + } + } + } + + // B. Check Regex + for (size_t i = 0; i < stop_regexes.size(); ++i) { + std::smatch match; + // We search the raw buffer_text + if (std::regex_search(buffer_text, match, stop_regexes[i])) { + size_t pos = match.position(0); + if (match_pos == std::string::npos || pos < match_pos) { + match_pos = pos; + detected_phrase = "REGEX:" + regex_patterns[i]; + } + } + } + + if (match_pos != std::string::npos) { + std::cout << "Debug: Stop phrase '" << detected_phrase << "' detected. Initiating ban logic." << std::endl; + + // Find the guilty token + size_t split_index = 0; + bool found_split = false; + for (size_t i = 0; i < token_offsets.size(); ++i) { + size_t token_start = token_offsets[i]; + std::string content = get_content_str(token_buffer[i]); + size_t token_end = token_start + content.length(); + + if (token_end > match_pos) { + split_index = i; + found_split = true; + break; + } + } + + if (found_split) { + // 1. Construct prompt from good tokens (DO NOT FLUSH) + std::string temp_prompt_suffix = ""; + std::deque good_tokens; + + for (size_t i = 0; i < split_index; ++i) { + json& item = token_buffer[i]; + if (item.contains("__raw_token_id")) item.erase("__raw_token_id"); + temp_prompt_suffix += get_content_str(item); + good_tokens.push_back(item); + } + + // 2. Identify Guilty Token & Add to Bans + json& guilty_item = token_buffer[split_index]; + int guilty_token_id = get_token_id(guilty_item); + + if (guilty_token_id == -1) { + std::string content = get_content_str(guilty_item); + auto tokens = ctx_server.tokenize(content, false); + if (!tokens.empty()) guilty_token_id = tokens[0]; + } + + if (guilty_token_id != -1) { + // Check if we are banning a different slot than before + if (ban_slot_index != (int)split_index) { + current_step_bans.clear(); + ban_slot_index = (int)split_index; + } + + current_step_bans.insert(guilty_token_id); + std::cout << "Debug: Banning token ID " << guilty_token_id << " at slot " << split_index << ". Total bans: " << current_step_bans.size() << std::endl; + + // 3. Cancel current task + request_cancel(current_task_id); + ctx_server.queue_results.remove_waiting_task_id(current_task_id); + + // 4. FIX STEP: Generate 1 token with ALL current bans + json fix_data = data; + fix_data["prompt"] = current_prompt_str + temp_prompt_suffix; + fix_data["n_predict"] = 1; + + // Robust logit_bias handling + if (!fix_data.contains("logit_bias")) { + fix_data["logit_bias"] = json::array(); + } + + if (fix_data["logit_bias"].is_array()) { + for (int banned_id : current_step_bans) { + fix_data["logit_bias"].push_back(json::array({banned_id, ban_bias})); + } + } else if (fix_data["logit_bias"].is_object()) { + for (int banned_id : current_step_bans) { + fix_data["logit_bias"][std::to_string(banned_id)] = ban_bias; + } + } + + std::cout << "Debug: Fix Data Logit Bias: " << fix_data["logit_bias"].dump() << std::endl; + + int id_fix = ctx_server.queue_tasks.get_new_id(); + *active_task_id = id_fix; // Update shared state for fix task + ctx_server.queue_results.add_waiting_task_id(id_fix); + + std::vector fix_inputs = tokenize_input_prompts( + llama_get_vocab(ctx_server.ctx), ctx_server.mctx, fix_data["prompt"], true, true + ); + + // Use helper + post_task_with_params(id_fix, fix_data, fix_inputs[0]); + + // Wait for the fix token + std::unordered_set fix_ids = { id_fix }; + server_task_result_ptr fix_result = nullptr; + while (!fix_result) { + fix_result = ctx_server.queue_results.recv_with_timeout(fix_ids, 1); + if (!fix_result && !sink.is_writable()) { + request_cancel(id_fix); + ctx_server.queue_results.remove_waiting_task_id(id_fix); + return false; + } + } + ctx_server.queue_results.remove_waiting_task_id(id_fix); + + // Check for error in fix result + if (fix_result->is_error()) { + std::cout << "Debug: Fix task failed with error." << std::endl; + sse(fix_result->to_json()); + return false; + } + + // Process fix token + json fix_token_json; + json raw_fix = fix_result->data; + + // Use format_partial_response_oaicompat for fix token too + if (oaicompat) { + std::vector parts = format_partial_response_oaicompat(*fix_result, completion_id); + if (!parts.empty()) fix_token_json = parts[0]; + } else { + fix_token_json = fix_result->data; + } + + if (raw_fix.contains("token")) fix_token_json["__raw_token_id"] = raw_fix["token"]; + + std::string fix_content = get_content_str(fix_token_json); + + // 5. RESUME STEP: Continue generation normally + json resume_data = data; + bool stop_after_fix = false; + + if (original_n_predict > 0) { + int pending = good_tokens.size() + 1; + if (total_tokens_streamed + pending >= original_n_predict) { + stop_after_fix = true; + } else { + resume_data["n_predict"] = original_n_predict - (total_tokens_streamed + pending); + } + } + + if (stop_after_fix) { + token_buffer = good_tokens; + token_buffer.push_back(fix_token_json); + + while (!token_buffer.empty()) { + json& item = token_buffer.front(); + if (item.contains("__raw_token_id")) item.erase("__raw_token_id"); + if (!sse(item)) { + request_cancel(*active_task_id); + ctx_server.queue_results.remove_waiting_task_id(*active_task_id); + return false; + } + total_tokens_streamed++; + token_buffer.pop_front(); + } + successful_completion = true; + goto cleanup; + } + + resume_data["prompt"] = current_prompt_str + temp_prompt_suffix + fix_content; + + current_task_id = ctx_server.queue_tasks.get_new_id(); + *active_task_id = current_task_id; // Update shared state for resume task + ctx_server.queue_results.add_waiting_task_id(current_task_id); + + std::vector resume_inputs = tokenize_input_prompts( + llama_get_vocab(ctx_server.ctx), ctx_server.mctx, resume_data["prompt"], true, true + ); + + // Use helper + post_task_with_params(current_task_id, resume_data, resume_inputs[0]); + + // 6. Update Buffer: Good Tokens + Fix Token + token_buffer = good_tokens; + token_buffer.push_back(fix_token_json); + + // REMOVED continue; to allow flush logic to run + } + } + } + + // 4. Standard Flush Logic + bool should_flush_all = result->is_stop() || result->is_error(); + + if (token_buffer.size() >= BUFFER_SIZE || should_flush_all) { + while (!token_buffer.empty()) { + if (!should_flush_all && token_buffer.size() < BUFFER_SIZE) { + break; + } + + json& item_to_send = token_buffer.front(); + if (item_to_send.contains("__raw_token_id")) item_to_send.erase("__raw_token_id"); + + current_prompt_str += get_content_str(item_to_send); + + // SMART BAN CLEARING LOGIC + if (ban_slot_index != -1) { + if (0 == ban_slot_index) { + // We are flushing the slot that had bans. + // This means it's now accepted (or we are forced to flush). + current_step_bans.clear(); + ban_slot_index = -1; + } else { + // We are flushing a preceding token. + // The banned slot shifts left. + ban_slot_index--; + } + } + + if (!sse(item_to_send)) { + request_cancel(current_task_id); + ctx_server.queue_results.remove_waiting_task_id(current_task_id); + return false; + } + + total_tokens_streamed++; + token_buffer.pop_front(); + + if (original_n_predict > 0 && total_tokens_streamed >= original_n_predict) { + request_cancel(current_task_id); + ctx_server.queue_results.remove_waiting_task_id(current_task_id); + successful_completion = true; + goto cleanup; + } + } + } + + if (result->is_error()) { + ctx_server.queue_results.remove_waiting_task_id(current_task_id); + return false; + } + + if (result->is_stop()) { + successful_completion = true; + break; + } + } + } + + cleanup: + bool ok = true; + if (successful_completion && oaicompat != OAICOMPAT_TYPE_ANTHROPIC && oaicompat != OAICOMPAT_TYPE_NONE) { + static const std::string done_message = "data: [DONE]\n\n"; + LOG_VERBOSE("data stream", { {"to_send", done_message} }); + if (!sink.write(done_message.c_str(), done_message.size())) { + ok = false; + } + } + sink.done(); + + // Cleanup the active task ID (which might be different from id_task in slow path) + ctx_server.queue_results.remove_waiting_task_id(*active_task_id); + + return ok; + } catch (const std::exception& e) { + // Catch any exceptions to prevent crashing the server + std::cerr << "Exception in streaming handler: " << e.what() << std::endl; + sse(json{{"error", {{"message", e.what()}, {"type", "server_error"}, {"code", 500}}}}); + sink.done(); + if (active_task_id) { + request_cancel(*active_task_id); + ctx_server.queue_results.remove_waiting_task_id(*active_task_id); + } + return false; + } catch (...) { + std::cerr << "Unknown exception in streaming handler" << std::endl; + sse(json{{"error", {{"message", "Unknown error"}, {"type", "server_error"}, {"code", 500}}}}); + sink.done(); + if (active_task_id) { + request_cancel(*active_task_id); + ctx_server.queue_results.remove_waiting_task_id(*active_task_id); + } + return false; + } + }; + + auto on_complete = [active_task_id, &ctx_server, request_cancel](bool) { + // Cancel the currently active task ID + int id_to_cancel = *active_task_id; + request_cancel(id_to_cancel); + ctx_server.queue_results.remove_waiting_task_id(id_to_cancel); + }; + + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + } + }; + + // Execute the complex logic + buffer_and_check_string_ban_and_rewind_logic(); + } + }; + + + + const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + auto data = json::parse(req.body); + std::vector files; // dummy + handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + data, + files, + req.is_connection_closed, + res, + OAICOMPAT_TYPE_NONE); + }; + + const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request& req, httplib::Response& res) { + auto body = json::parse(req.body); + json data = oaicompat_chat_params_parse(body); + std::vector files; // dummy + handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + data, + files, + req.is_connection_closed, + res, + OAICOMPAT_TYPE_COMPLETION); + }; + + const auto handle_models = [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) { + json models = { + {"object", "list"}, + {"data", { + { + {"id", params.model_alias}, + {"object", "model"}, + {"created", std::time(0)}, + {"owned_by", "llamacpp"}, + {"meta", model_meta} + }, + }} + }; + + res.set_content(models.dump(), "application/json; charset=utf-8"); + }; + + + + const auto handle_chat_completions = [&ctx_server, ¶ms, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + auto body = json::parse(req.body); + std::vector files; + json data = oaicompat_chat_params_parse(ctx_server.model, body, ctx_server.oai_parser_opt, files); + handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + data, + files, + req.is_connection_closed, + res, + OAICOMPAT_TYPE_CHAT); + }; + + const auto handle_anthropic_messages = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + std::vector files; + json body = json::parse(req.body); + json body_parsed = anthropic_params_from_json( + ctx_server.model, + body, + ctx_server.oai_parser_opt, + files); + return handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + body_parsed, + files, + req.is_connection_closed, + res, + OAICOMPAT_TYPE_ANTHROPIC); + }; + + const auto handle_anthropic_count_tokens = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + std::vector files; + json body = json::parse(req.body); + + // Parse the Anthropic request (max_tokens is not required for count_tokens) + json body_parsed = anthropic_params_from_json( + ctx_server.model, + body, + ctx_server.oai_parser_opt, + files); + + json prompt = body_parsed.at("prompt"); + llama_tokens tokens = tokenize_mixed(llama_get_vocab(ctx_server.ctx), prompt, true, true); + + res_ok(res, {{"input_tokens", static_cast(tokens.size())}}); + return res; + }; + + // same with handle_chat_completions, but without inference part + const auto handle_apply_template = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) { + auto body = json::parse(req.body); + std::vector files; // dummy, unused + json data = oaicompat_chat_params_parse(ctx_server.model, body,ctx_server.oai_parser_opt, files); + res_ok(res, { { "prompt", std::move(data.at("prompt")) } }); + }; + + const auto handle_infill = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + json data = json::parse(req.body); + const int id_task = ctx_server.queue_tasks.get_new_id(); + server_tokens token; // dummy tokens + ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.request_completion(id_task, -1, data, true, false, std::move(token)); + std::vector files; // dummy + handle_completions_impl( + SERVER_TASK_TYPE_INFILL, + data, + files, + req.is_connection_closed, + res, + OAICOMPAT_TYPE_NONE); // infill is not OAI compatible + }; + + const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + const json body = json::parse(req.body); + + std::vector tokens; + if (body.count("content") != 0) { + const bool add_special = json_value(body, "add_special", false); + tokens = ctx_server.tokenize(body.at("content"), add_special); + } + const json data = format_tokenizer_response(tokens); + return res.set_content(data.dump(), "application/json; charset=utf-8"); + }; + + const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + const json body = json::parse(req.body); + + std::string content; + if (body.count("tokens") != 0) { + const std::vector tokens = body.at("tokens"); + content = tokens_to_str(ctx_server.ctx, tokens); + } + + const json data = format_detokenized_response(content); + return res.set_content(data.dump(), "application/json; charset=utf-8"); + }; + + const auto handle_embeddings_impl = [&ctx_server](const httplib::Request& req, httplib::Response& res, oaicompat_type oaicompat) { + if (!ctx_server.params_base.embedding) { + res_err(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + + if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + res_err(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + const json body = json::parse(req.body); + + // for the shape of input/content, see tokenize_input_prompts() + json prompt; + if (body.count("input") != 0) { + prompt = body.at("input"); + } + else if (body.contains("content")) { + oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible + prompt = body.at("content"); + } + else { + res_err(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + bool use_base64 = false; + if (body.count("encoding_format") != 0) { + const std::string& format = body.at("encoding_format"); + if (format == "base64") { + use_base64 = true; + } + else if (format != "float") { + res_err(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } + auto vocab = llama_get_vocab(ctx_server.ctx); + auto tokenized_prompts = tokenize_input_prompts(vocab, ctx_server.mctx, prompt, true, true); + for (const auto& tokens : tokenized_prompts) { + // this check is necessary for models that do not add BOS token to the input + if (tokens.empty()) { + res_err(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } + + int embd_normalize = 2; // default to Euclidean/L2 norm + if (body.count("embd_normalize") != 0) { + embd_normalize = body.at("embd_normalize"); + if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx)); + } + } + + // create and queue the task + json responses = json::array(); + server_response_reader rd(ctx_server); + { + std::vector tasks; + for (size_t i = 0; i < tokenized_prompts.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tokenized_prompts[i]); + + // OAI-compat + task.params.oaicompat = oaicompat; + task.params.embd_normalize = embd_normalize; + task.embedding = true; // probably not needed + tasks.push_back(std::move(task)); + } + rd.post_tasks(std::move(tasks)); + } + + // wait for the results + auto all_results = rd.wait_for_all(req.is_connection_closed); + + // collect results + if (all_results.is_terminated) { + llama_decode_stop(); + return; // connection is closed + } + else if (all_results.error) { + res_err(res, all_results.error->to_json()); + return; + } + else { + for (auto& res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); + } + } + + // write JSON response + json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING + ? format_embeddings_response_oaicompat(body, responses, use_base64) + : json(responses); + res_ok(res, root); + + }; + + const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request& req, httplib::Response& res) { + handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE); + }; + + const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request& req, httplib::Response& res) { + handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING); + }; + + + const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) { + json result = json::array(); + for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) { + auto & la = ctx_server.lora_adapters[i]; + result.push_back({ + {"id", i}, + {"path", la.path}, + {"scale", la.scale}, + }); + } + res.set_content(result.dump(), "application/json"); + res.status = 200; // HTTP OK + }; + + + const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) { + const std::vector body = json::parse(req.body); + int max_idx = ctx_server.lora_adapters.size(); + + // clear existing value + for (auto & la : ctx_server.lora_adapters) { + la.scale = 0.0f; + } + + // set value + for (auto entry : body) { + int id = entry.at("id"); + float scale = entry.at("scale"); + if (0 <= id && id < max_idx) { + ctx_server.lora_adapters[id].scale = scale; + } else { + throw std::runtime_error("invalid adapter id"); + } + } + + server_task task; + task.type = SERVER_TASK_TYPE_SET_LORA; + const int id_task = ctx_server.queue_tasks.post(std::move(task)); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + res.set_content(result.data.dump(), "application/json"); + res.status = 200; // HTTP OK + }; + + const auto list_saved_prompts = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) { + json response = json::array(); + + try { + for (const auto& entry : fs::directory_iterator(params.slot_save_path)) { + if (!entry.is_regular_file() || entry.file_size() < 12) { + continue; + } + + std::ifstream file(entry.path(), std::ios::binary); + if (!file) continue; + + uint32_t magic, version, n_token_count; + file.read(reinterpret_cast(&magic), sizeof(magic)); + file.read(reinterpret_cast(&version), sizeof(version)); + file.read(reinterpret_cast(&n_token_count), sizeof(n_token_count)); + + if (magic != LLAMA_STATE_SEQ_MAGIC || + version != LLAMA_STATE_SEQ_VERSION || + entry.file_size() < (12 + (n_token_count * sizeof(llama_token)))) { + continue; + } + + std::vector tokens(n_token_count); + file.read(reinterpret_cast(tokens.data()), tokens.size() * sizeof(llama_token)); + + //C++17 is not modern enough to have a nice and portable way to get the mtime of a file + //so the following seems to be needed + auto ftime = fs::last_write_time(entry.path()); + auto system_time = std::chrono::time_point_cast( + ftime - fs::file_time_type::clock::now() + std::chrono::system_clock::now() + ); + std::time_t c_time = std::chrono::system_clock::to_time_t(system_time); + std::tm tm_struct; + #if defined(_WIN32) + localtime_s(&tm_struct, &c_time); + #else + localtime_r(&c_time, &tm_struct); + #endif + std::ostringstream oss; + oss << std::put_time(&tm_struct, "%Y-%m-%d %H:%M:%S"); + auto str_time = oss.str(); + + + response.push_back({ + {"filename", entry.path().filename().string()}, + {"filesize", entry.file_size()}, + {"mtime", str_time}, + {"token_count", n_token_count}, + {"prompt", tokens_to_str(ctx_server.ctx, tokens)} + }); + } + } catch (const std::exception& e) { + res.status = 500; + response = {{"error", e.what()}}; + } + res.set_content(response.dump(), "application/json; charset=utf-8"); + }; + + const auto list_slot_prompts = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) { + json response = json::array(); + for (server_slot & slot : ctx_server.slots) { + response.push_back({ + {"slot_id", slot.id}, + {"token_count", slot.cache_tokens.size()}, + {"prompt", slot.cache_tokens.detokenize(ctx_server.ctx, true) } + }); + } + res.set_content(response.dump(), "application/json; charset=utf-8"); + }; + + + const auto delete_saved_prompt = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res)-> void { + json response; + namespace fs = std::filesystem; + + try { + const json body = json::parse(req.body); + const std::string filename_str = body.at("filename"); + + // prevent directory traversal attacks + if (filename_str.find("..") != std::string::npos || filename_str.find('/') != std::string::npos || filename_str.find('\\') != std::string::npos) { + res.status = 400; + response = {{"error", "Invalid filename format."}}; + res.set_content(response.dump(), "application/json; charset=utf-8"); + return; + } + + const fs::path file_to_delete = fs::path(params.slot_save_path) / fs::path(filename_str); + + if (!fs::exists(file_to_delete) || !fs::is_regular_file(file_to_delete)) { + res.status = 404; + response = {{"error", "File not found."}}; + res.set_content(response.dump(), "application/json; charset=utf-8"); + return; + } + + if (fs::remove(file_to_delete)) { + response = { + {"status", "deleted"}, + {"filename", filename_str} + }; + } else { + res.status = 500; + response = {{"error", "Failed to delete the file."}}; + } + } catch (const json::parse_error& e) { + res.status = 400; + response = {{"error", "Invalid JSON request body."}}; + } catch (const json::out_of_range& e) { + res.status = 400; + response = {{"error", "Missing 'filename' key in request body."}}; + } catch (const std::exception& e) { + res.status = 500; + response = {{"error", e.what()}}; + } + res.set_content(response.dump(), "application/json; charset=utf-8"); + }; + + const auto rename_saved_prompt = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res)-> void { + json response; + namespace fs = std::filesystem; + + try { + const json body = json::parse(req.body); + const std::string old_filename_str = body.at("old_filename"); + const std::string new_filename_str = body.at("new_filename"); + + if (old_filename_str.find("..") != std::string::npos || old_filename_str.find_first_of("/\\") != std::string::npos || + new_filename_str.find("..") != std::string::npos || new_filename_str.find_first_of("/\\") != std::string::npos) { + res.status = 400; + response = {{"error", "Invalid filename format."}}; + res.set_content(response.dump(), "application/json; charset=utf-8"); + return; + } + + const fs::path old_path = fs::path(params.slot_save_path) / old_filename_str; + const fs::path new_path = fs::path(params.slot_save_path) / new_filename_str; + + if (!fs::exists(old_path) || !fs::is_regular_file(old_path)) { + res.status = 404; + response = {{"error", "Source file not found."}}; + res.set_content(response.dump(), "application/json; charset=utf-8"); + return; + } + + if (fs::exists(new_path)) { + res.status = 409; + response = {{"error", "Destination filename already exists."}}; + res.set_content(response.dump(), "application/json; charset=utf-8"); + return; + } + + std::error_code ec; + fs::rename(old_path, new_path, ec); + + if (ec) { + res.status = 500; + response = {{"error", "Failed to rename file: " + ec.message()}}; + } else { + response = { + {"status", "renamed"}, + {"old_filename", old_filename_str}, + {"new_filename", new_filename_str} + }; + } + + } catch (const json::parse_error& e) { + res.status = 400; + response = {{"error", "Invalid JSON request body."}}; + } catch (const json::out_of_range& e) { + res.status = 400; + response = {{"error", "Missing 'old_filename' or 'new_filename' in request body."}}; + } catch (const std::exception& e) { + res.status = 500; + response = {{"error", e.what()}}; + } + + res.set_content(response.dump(), "application/json; charset=utf-8"); + }; + + auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) { + return [content, len, mime_type](const httplib::Request &, httplib::Response & res) { + res.set_content(reinterpret_cast(content), len, mime_type); + return false; + }; + }; +#ifdef SQLITE3_MODERN_CPP_SUPPORT + const auto handle_version = [¶ms, sqlite_extension_loaded](const httplib::Request&, httplib::Response& res) { + res.set_content( + json{{"version", 4}, + {"features", {{"sql", !params.sql_save_file.empty()}, {"zstd_compression", sqlite_extension_loaded}}}}.dump(), + "application/json" + ); + }; +#else + const auto handle_version = [](const httplib::Request&, httplib::Response& res)-> void { + res.set_content( + json{{"version", 4}, + {"features", {{"sql", false}, {"zstd_compression", false}}}}.dump(), + "application/json" + ); + }; +#endif + +#ifdef SQLITE3_MODERN_CPP_SUPPORT + auto db_handler = [db_handle](auto func) { + return [func, db_handle](const httplib::Request& req, httplib::Response& res) { + res.set_header("Access-Control-Allow-Origin", "*"); + try { + const json body = !req.body.empty() ? json::parse(req.body) : json::object(); + func(*db_handle, body, req, res); + } catch(const std::exception& e) { + res.status = 500; + res.set_content( + json{{"ok", false}, {"message", e.what()}}.dump(), + "application/json" + ); + } + }; + }; +#else + auto db_handler = [db_handle](auto func) { + return [func, db_handle](const httplib::Request& req, httplib::Response& res) { + res.set_header("Access-Control-Allow-Origin", "*"); + res.status = 500; + res.set_content( + json{{"ok", false}, {"message", "Sqlite3 support was not enabled. Recompile with '-DLLAMA_SERVER_SQLITE3=ON'"}}.dump(), + "application/json" + ); + }; + }; +#endif + + const auto normalize_store_name = [](const std::string& storeName) { + if(storeName.empty()) return std::string("sessions"); + + std::string normalized; + normalized.reserve(storeName.size()); + + for(char c : storeName) { + if(std::isalpha(static_cast(c))) { + normalized.push_back(std::tolower(static_cast(c))); + } + } + + return normalized.empty() ? "sessions" : normalized; + }; + + const auto get_key_string = [](const json& j) { + return j.is_string() ? j.get() : j.dump(); + }; + + + const auto handle_load = db_handler([normalize_store_name, get_key_string](auto& db, const json& body, auto&, auto& res) { + std::string data; + const std::string store = normalize_store_name(body["storeName"]); + db.db << "SELECT data FROM " + store + " WHERE key = ?" << get_key_string(body["key"]) >> data; + if(data.empty()) { + res.status = 404; + res.set_content(json{{"ok", false}, {"message", "Key not found"}}.dump(), "application/json"); + } else { + json response{{"ok", true}}; + response["result"] = (store == "names") ? json(data) : json::parse(data); + res.set_content(response.dump(), "application/json"); + } + }); + + const auto handle_save = db_handler([normalize_store_name, get_key_string](auto& db, const json& body, auto&, auto& res) { + const std::string store = normalize_store_name(body["storeName"]); + const std::string data = (store == "names") ? body["data"].get() : body["data"].dump(); + db.db << "INSERT OR REPLACE INTO " + store + " (key, data) VALUES (?, ?)" << get_key_string(body["key"]) << data; + res.set_content(json{{"ok", true}, {"result", "Data saved successfully"}}.dump(), "application/json"); + }); + + const auto handle_rename = db_handler([get_key_string](auto& db, const json& body, auto&, auto& res) { + db.db << "UPDATE names SET data = ? WHERE key = ?" + << body["newName"].get() + << get_key_string(body["key"]); + res.set_content(json{{"ok", true}, {"result", "Session renamed successfully"}}.dump(), "application/json"); + }); + + const auto handle_all = db_handler([normalize_store_name](auto& db, const json& body, auto&, auto& res) { + json result = json::object(); + db.db << "SELECT key, data FROM " + normalize_store_name(body["storeName"]) >> + [&](const std::string& key, const std::string& data) { + result[key] = json::parse(data); + }; + res.set_content(json{{"ok", true}, {"result", result}}.dump(), "application/json"); + }); + + const auto handle_sessions = db_handler([](auto& db, const json& body, auto&, auto& res) { + json result = json::object(); + db.db << "SELECT key, data FROM names" >> [&](const std::string& key, const std::string& data) { + result[key] = data; + }; + res.set_content(json{{"ok", true}, {"result", result}}.dump(), "application/json"); + }); + + const auto handle_delete = db_handler([normalize_store_name, get_key_string](auto& db, const json& body, auto&, auto& res) { + db.db << "DELETE FROM " + normalize_store_name(body["storeName"]) + " WHERE key = ?" + << get_key_string(body["key"]); + res.set_content(json{{"ok", true}, {"result", "Session deleted successfully"}}.dump(), "application/json"); + }); + + const auto handle_vacuum = db_handler([](auto& db, const json& body, auto&, auto& res) { + json result = json::object(); + db.db << "VACUUM"; + res.set_content(json{"ok", true}.dump(), "application/json"); + }); + + const auto handle_zstd_get_configs = db_handler([](auto& db, const json& body, auto&, auto& res) { + json result = json::object(); + db.db << "SELECT id, config FROM _zstd_configs" >> [&](const std::string id, const std::string& config) { + result[id] = config; + }; + res.set_content(json{{"ok", true}, {"configs", result}}.dump(), "application/json"); + }); + + const auto handle_zstd_maintenance = db_handler([](auto& db, const json& body, auto&, auto& res) { + std::string data; + if (body["duration"].is_null()) { + db.db << "select zstd_incremental_maintenance(?, ?)" << nullptr << body["db_load"].get() >> data; + } + else { + db.db << "select zstd_incremental_maintenance(?, ?)" << body["duration"].get() << body["db_load"].get() >> data; + } + json response{{"ok", true}}; + response["result"] = json::parse(data); + res.set_content(response.dump(), "application/json"); + }); + + const auto handle_zstd_enable = db_handler([](auto& db, const json& body, auto&, auto& res) { + db.db << "select zstd_enable_transparent('{\"table\": \"" + body["table"].get() + "\",\"column\": \"" + body["column"].get() + "\", \"compression_level\": " + std::to_string(body["compression_level"].get()) + ", \"dict_chooser\": \"''a''\", \"train_dict_samples_ratio\": " + std::to_string(body["train_dict_samples_ratio"].get()) + "}')"; + res.set_content(json{"ok", true}.dump(), "application/json"); + }); + + const auto handle_zstd_config_update = db_handler([](auto& db, const json& body, auto&, auto& res) { + std::string patch_json = "{\"compression_level\": " + std::to_string(body["compression_level"].get()) + ", \"train_dict_samples_ratio\": " + std::to_string(body["train_dict_samples_ratio"].get()) + "}"; + db.db << "update _zstd_configs set config = json_patch(config, '" + patch_json + "')"; + res.set_content(json{{"ok", true}}.dump(), "application/json"); + }); + + // + // Router + // + if (params.webui == COMMON_WEBUI_NONE) { + LLAMA_LOG_INFO("Web UI is disabled\n"); + } + else { + // register static assets routes + if (!params.public_path.empty()) { + // Set the base directory for serving static files + svr->set_base_dir(params.public_path); + } + + { + // register static assets routes + if (!params.public_path.empty()) { + // Set the base directory for serving static files + bool is_found = svr->set_mount_point("/", params.public_path); + if (!is_found) { + GGML_ABORT("%s: static assets path not found: %s\n", __func__, params.public_path.c_str()); + return 1; + } + } + else { + + // using embedded static index.html + svr->Get("/", [params](const httplib::Request& req, httplib::Response& res) { + if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) { + res.set_content("Error: gzip is not supported by this browser", "text/plain"); + } + else { + res.set_header("Content-Encoding", "gzip"); + // COEP and COOP headers, required by pyodide (python interpreter) + res.set_header("Cross-Origin-Embedder-Policy", "require-corp"); + res.set_header("Cross-Origin-Opener-Policy", "same-origin"); + if (params.webui == COMMON_WEBUI_AUTO) { + res.set_content(reinterpret_cast(index_html_gz), index_html_gz_len, "text/html; charset=utf-8"); + } + else if (params.webui == COMMON_WEBUI_LLAMACPP) { + res.set_content(reinterpret_cast(index_llamacpp_html_gz), index_llamacpp_html_gz_len, "text/html; charset=utf-8"); + } + else { + res.set_content(reinterpret_cast(index_html_gz), index_html_gz_len, "text/html; charset=utf-8"); + } + } + return false; + }); + } + } + } + // register API routes + svr->Get ("/health", handle_health); + svr->Get ("/metrics", handle_metrics); + svr->Get ("/props", handle_props); + svr->Get("/v1/props", handle_props_simple); + svr->Get ("/v1/models", handle_models); + svr->Post("/completion", handle_completions); // legacy + svr->Post("/completions", handle_completions); // legacy + svr->Post("/v1/completions", handle_completions_oai); + svr->Post("/chat/completions", handle_chat_completions); + svr->Post("/v1/chat/completions", handle_chat_completions); + svr->Post("/v1/messages", handle_anthropic_messages); + svr->Post("/v1/messages/count_tokens", handle_anthropic_count_tokens); + svr->Post("/infill", handle_infill); + svr->Post("/embedding", handle_embeddings); // legacy + svr->Post("/embeddings", handle_embeddings); + svr->Post("/v1/embeddings", handle_embeddings_oai); + svr->Post("/tokenize", handle_tokenize); + svr->Post("/detokenize", handle_detokenize); + svr->Post("/apply-template", handle_apply_template); + // LoRA adapters hotswap + svr->Get ("/lora-adapters", handle_lora_adapters_list); + svr->Post("/lora-adapters", handle_lora_adapters_apply); + // Save & load slots + svr->Get ("/slots", handle_slots); + svr->Get ("/slots/list", list_slot_prompts); + if (!params.slot_save_path.empty()) { + // these endpoints rely on slot_save_path existing + svr->Post("/slots/:id_slot", handle_slots_action); + svr->Get ("/list", list_saved_prompts); + svr->Post("/delete_prompt", delete_saved_prompt); + svr->Post("/rename_prompt", rename_saved_prompt); + + } + + svr->Get ("/version", handle_version); + if (!params.sql_save_file.empty()) { + // these endpoints rely on sql_save_file existing + svr->Post("/load", handle_load); + svr->Post("/save", handle_save); + svr->Post("/rename", handle_rename); + svr->Post("/all", handle_all); + svr->Post("/sessions", handle_sessions); + svr->Get ("/sessions", handle_sessions); + svr->Post("/delete", handle_delete); + //VACUUM is there for the extension but does not require the extension + svr->Get ("/vacuum", handle_vacuum); +#ifdef SQLITE3_MODERN_CPP_SUPPORT + if (sqlite_extension_loaded) { + svr->Get ("/zstd_get_configs", handle_zstd_get_configs); + svr->Post("/zstd_incremental_maintenance", handle_zstd_maintenance); + svr->Post("/zstd_enable_transparent", handle_zstd_enable); + svr->Post("/zstd_update_transparent", handle_zstd_config_update); + } +#endif + } + // + // Start the server + // + if (params.n_threads_http < 1) { + // +2 threads for monitoring endpoints + params.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1); + } + log_data["n_threads_http"] = std::to_string(params.n_threads_http); + svr->new_task_queue = [¶ms] { return new httplib::ThreadPool(params.n_threads_http); }; + + LOG_INFO("HTTP server listening", log_data); + + // run the HTTP server in a thread - see comment below + std::thread t([&]() { + if (!svr->listen_after_bind()) { + state.store(SERVER_STATE_ERROR); + return 1; + } + + return 0; + }); + + ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) { + ctx_server.process_single_task(std::move(task)); + }); + ctx_server.queue_tasks.on_finish_multitask(std::bind( + &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1)); + ctx_server.queue_tasks.on_update_slots(std::bind( + &server_context::update_slots, &ctx_server)); + ctx_server.queue_results.on_multitask_update(std::bind( + &server_queue::update_multitask, + &ctx_server.queue_tasks, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3 + )); + + shutdown_handler = [&](int) { + ctx_server.queue_tasks.terminate(); + }; + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) + struct sigaction sigint_action; + sigint_action.sa_handler = signal_handler; + sigemptyset (&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); + sigaction(SIGTERM, &sigint_action, NULL); +#elif defined (_WIN32) + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { + return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; + }; + SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); +#endif + + ctx_server.queue_tasks.start_loop(); + + svr->stop(); + t.join(); + + llama_backend_free(); + + return 0; +} From 958cd0f98cc8c6f63f9fe0dceea691a231721d5e Mon Sep 17 00:00:00 2001 From: SneedwareInc <254158255+SneedwareInc@users.noreply.github.com> Date: Tue, 27 Jan 2026 21:30:25 +0100 Subject: [PATCH 3/8] Better separation, as requested --- examples/server/server.cpp | 54 ++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6e7657fe8..46b2ac823 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1041,9 +1041,7 @@ int main(int argc, char ** argv) { -// handle completion-like requests (completion, chat, infill) - // we can optionally provide a custom format for partial results and final results - const auto handle_completions_impl = [&ctx_server, ¶ms]( +const auto handle_completions_impl = [&ctx_server, ¶ms]( server_task_type type, json& data, const std::vector& files, @@ -1053,7 +1051,7 @@ int main(int argc, char ** argv) { GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); // ---------------------------------------------------------------- - // 1. Regex Validation + // 1. Regex Validation (Common) // ---------------------------------------------------------------- auto validate_regex_list = [&](const std::string& field_name) -> std::string { if (data.contains(field_name) && data[field_name].is_array()) { @@ -1082,26 +1080,9 @@ int main(int argc, char ** argv) { } const auto completion_id = gen_chatcmplid(); - - // Process prompt / inputs - std::vector inputs; - try { - const auto& prompt = data.at("prompt"); - if (oaicompat && ctx_server.mctx != nullptr) { - inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); - } - else { - inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true); - } - } - catch (const std::exception& e) { - res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); - return; - } // ---------------------------------------------------------------- // Check if we need the complex "Banned String" logic - // Only enable if the lists are present AND contain actual strings. // ---------------------------------------------------------------- auto list_has_content = [&](const std::string& key) { if (data.contains(key) && data[key].is_array()) { @@ -1120,7 +1101,7 @@ int main(int argc, char ** argv) { if (!has_banned_content) { // ---------------------------------------------------------------- - // PATH A: Standard Logic (server_response_reader) + // PATH A: Standard Logic (The "Old Way") // ---------------------------------------------------------------- // need to store the reader as a pointer, so that it won't be destroyed when the handle returns @@ -1129,6 +1110,18 @@ int main(int argc, char ** argv) { try { std::vector tasks; + + const auto& prompt = data.at("prompt"); + + // process prompt + std::vector inputs; + + if (oaicompat && ctx_server.mctx != nullptr) { + inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); + } + else { + inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true); + } tasks.reserve(inputs.size()); for (size_t i = 0; i < inputs.size(); i++) { server_task task = server_task(type); @@ -1289,6 +1282,22 @@ int main(int argc, char ** argv) { // PATH B: Banned Content Logic (Slow Path with Buffering & Rewind) // ---------------------------------------------------------------- auto buffer_and_check_string_ban_and_rewind_logic = [&]() { + // Process prompt / inputs (Duplicated here to keep Path A isolated) + std::vector inputs; + try { + const auto& prompt = data.at("prompt"); + if (oaicompat && ctx_server.mctx != nullptr) { + inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); + } + else { + inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true); + } + } + catch (const std::exception& e) { + res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return; + } + // Helper to mimic request_cancel using the task queue directly auto request_cancel = [&ctx_server](int id_target) { server_task task(SERVER_TASK_TYPE_CANCEL); @@ -1950,6 +1959,7 @@ int main(int argc, char ** argv) { + const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { auto data = json::parse(req.body); std::vector files; // dummy From b95340aad58293082b0407aff340b2e51e0f1d35 Mon Sep 17 00:00:00 2001 From: SneedwareInc <254158255+SneedwareInc@users.noreply.github.com> Date: Tue, 27 Jan 2026 22:49:29 +0100 Subject: [PATCH 4/8] Comment --- examples/server/server.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 46b2ac823..c87a22f4f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1039,8 +1039,8 @@ int main(int argc, char ** argv) { }; - - + // handle completion-like requests (completion, chat, infill) + // we can optionally provide a custom format for partial results and final results const auto handle_completions_impl = [&ctx_server, ¶ms]( server_task_type type, json& data, From 64438351e1172d18533f50ee048dc542720408df Mon Sep 17 00:00:00 2001 From: SneedwareInc <254158255+SneedwareInc@users.noreply.github.com> Date: Thu, 29 Jan 2026 19:33:23 +0100 Subject: [PATCH 5/8] Separate into functions --- examples/server/server.cpp | 1502 +++++++++++++++--------------------- 1 file changed, 643 insertions(+), 859 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c87a22f4f..b1674a1d6 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -10,8 +10,7 @@ #include "llama.h" #include "llama-vocab.h" #include -#include -#include + // mime type for sending response #define MIMETYPE_JSON "application/json; charset=utf-8" @@ -432,6 +431,522 @@ inline void signal_handler(int signal) { shutdown_handler(signal); } + + + + +// --- START OF HELPER FUNCTIONS --- + +static std::string validate_banned_regex(const json& data) { + auto validate_list = [&](const std::string& field_name) -> std::string { + if (data.contains(field_name) && data[field_name].is_array()) { + for (const auto& val : data[field_name]) { + if (val.is_string()) { + std::string s = val.get(); + if (!s.empty()) { + try { + std::regex re(s); + } catch (const std::regex_error&) { + return s; + } + } + } + } + } + return ""; + }; + + std::string invalid = validate_list("banned_regex"); + if (invalid.empty()) { + invalid = validate_list("banned_regex_case_insensitive"); + } + return invalid; +} + +static void handle_completions_banned_impl( + server_context& ctx_server, + server_task_type type, + json data, // Passed by value for internal modification + server_tokens tokens, + const std::function& is_connection_closed, + httplib::Response& res, + oaicompat_type oaicompat, + std::string completion_id +) { + // Local version of res_err + auto res_err = [](httplib::Response& res, json error_data) { + json final_response{ {"error", error_data} }; + res.set_content(final_response.dump(), MIMETYPE_JSON); + res.status = json_value(error_data, "code", 500); + }; + + std::string model_name = get_model_name(ctx_server.params_base.model); + + auto create_task = [type, oaicompat, completion_id, model_name](int id, int index, server_tokens t, json& task_data) { + server_task task(type); + task.id = id; + task.index = index; + task.tokens = std::move(t); + task.data = task_data; + task.id_slot = json_value(task_data, "id_slot", -1); + task.params.oaicompat = oaicompat; + task.params.oaicompat_cmpl_id = completion_id; + task.params.oaicompat_model = model_name; + return task; + }; + + auto request_cancel = [&ctx_server](int id_target) { + server_task task(SERVER_TASK_TYPE_CANCEL); + task.id_target = id_target; + std::vector tasks; + tasks.push_back(std::move(task)); + ctx_server.queue_tasks.post(std::move(tasks), true); + }; + + auto send_sse = [oaicompat](httplib::DataSink& sink, const json& payload) -> bool { + if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) { + return server_sent_anthropic_event(sink, payload); + } else { + return server_sent_event(sink, payload); + } + }; + + auto post_single_task = [&ctx_server, create_task](int id_task, json& task_data, server_tokens& t) { + std::vector tasks; + tasks.push_back(create_task(id_task, 0, std::move(t), task_data)); + ctx_server.queue_tasks.post(std::move(tasks)); + }; + + // Initial Task Setup + const int id_task = ctx_server.queue_tasks.get_new_id(); + ctx_server.queue_results.add_waiting_task_id(id_task); + post_single_task(id_task, data, tokens); + + bool stream = json_value(data, "stream", false); + + // Non-Streaming Logic + if (!stream) { + std::unordered_set ids = { id_task }; + server_task_result_ptr result = nullptr; + + while (!result) { + result = ctx_server.queue_results.recv_with_timeout(ids, 1); + if (!result && is_connection_closed()) { + request_cancel(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + return; + } + } + + if (!result->is_error()) { + json result_json; + if (oaicompat) { + result_json = format_final_response_oaicompat(data, result->data, completion_id, false); + } else { + result_json = result->to_json(); + } + res.set_content(result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); + } else { + res_err(res, result->to_json()); + } + ctx_server.queue_results.remove_waiting_task_id(id_task); + return; + } + + // Streaming Logic (Buffering & Rewind) + auto active_task_id = std::make_shared(id_task); + + const auto chunked_content_provider = [id_task, active_task_id, &ctx_server, completion_id, oaicompat, data, request_cancel, post_single_task, send_sse](size_t, httplib::DataSink& sink) mutable { + try { + bool successful_completion = false; + + // --- Parse Banned Config --- + std::vector stop_phrases; + if (data.contains("banned_strings") && data["banned_strings"].is_array()) { + for (const auto& val : data["banned_strings"]) { + if (val.is_string() && !val.get().empty()) { + stop_phrases.push_back(val.get()); + } + } + } + std::sort(stop_phrases.begin(), stop_phrases.end(), [](const std::string& a, const std::string& b) { + return a.length() > b.length(); + }); + + std::vector regex_patterns; + std::vector stop_regexes; + auto add_regex = [&](const std::string& field, bool icase) { + if (data.contains(field) && data[field].is_array()) { + for (const auto& val : data[field]) { + if (val.is_string() && !val.get().empty()) { + auto flags = std::regex_constants::ECMAScript; + if (icase) flags |= std::regex_constants::icase; + stop_regexes.emplace_back(val.get(), flags); + regex_patterns.push_back(val.get()); + } + } + } + }; + add_regex("banned_regex", false); + add_regex("banned_regex_case_insensitive", true); + + float ban_bias = json_value(data, "banned_bias", -10000.0f); + size_t manual_buffer_size = json_value(data, "banbuffer_size", (size_t)0); + int original_n_predict = json_value(data, "n_predict", -1); + int total_tokens_streamed = 0; + + // Calculate Buffer Size + size_t BUFFER_SIZE = manual_buffer_size; + if (BUFFER_SIZE == 0) { + size_t max_len = stop_phrases.empty() ? 0 : stop_phrases[0].length(); + for (const auto& pat : regex_patterns) max_len = std::max(max_len, pat.length()); + BUFFER_SIZE = std::max((size_t)1, max_len + 1); + } + + std::deque token_buffer; + int current_task_id = id_task; + std::set current_step_bans; + int ban_slot_index = -1; + std::string current_prompt_str = json_value(data, "prompt", std::string("")); + + auto get_content_str = [](const json& j) -> std::string { + if (j.contains("choices") && !j["choices"].empty()) { + auto& d = j["choices"][0]["delta"]; + if (d.contains("content") && d["content"].is_string()) return d["content"]; + } + if (j.contains("content") && j["content"].is_string()) return j["content"]; + return ""; + }; + + // FIXED: Added type checks to prevent string-to-int conversion errors + auto get_token_id = [](const json& j) -> int { + if (j.contains("__raw_token_id") && j["__raw_token_id"].is_number()) return j["__raw_token_id"]; + if (j.contains("token") && j["token"].is_number()) return j["token"]; + if (j.contains("id") && j["id"].is_number()) return j["id"]; + return -1; + }; + + auto to_lower_str = [](std::string s) { + std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); }); + return s; + }; + + auto print_debug_buffer = [&](const std::deque& buf) { + std::cout << "Debug TokenBuffer (Size " << BUFFER_SIZE << "): ["; + size_t print_len = std::max(buf.size(), BUFFER_SIZE); + for (size_t i = 0; i < print_len; ++i) { + if (i < buf.size()) { + std::string content = get_content_str(buf[i]); + std::string escaped; + for (char c : content) { + if (c == '\n') escaped += "\\n"; + else if (c == '"') escaped += "\\\""; + else escaped += c; + } + std::cout << "\"" << escaped << "\""; + } else { + std::cout << "\"\""; + } + if (i < print_len - 1) std::cout << ", "; + } + std::cout << "]" << std::endl; + }; + + while (true) { + *active_task_id = current_task_id; + + if (!sink.is_writable()) { + request_cancel(current_task_id); + ctx_server.queue_results.remove_waiting_task_id(current_task_id); + return false; + } + + std::unordered_set ids = { current_task_id }; + server_task_result_ptr result = nullptr; + while (!result) { + result = ctx_server.queue_results.recv_with_timeout(ids, 1); + if (!result && !sink.is_writable()) { + request_cancel(current_task_id); + ctx_server.queue_results.remove_waiting_task_id(current_task_id); + return false; + } + } + + std::vector items_to_buffer; + if (!result->is_error()) { + std::vector parts; + if (oaicompat) parts = format_partial_response_oaicompat(*result, completion_id); + else parts.push_back(result->data); + + json raw_data = result->data; + for (const auto& r : parts) { + json item = r; + if (raw_data.contains("token")) item["__raw_token_id"] = raw_data["token"]; + items_to_buffer.push_back(item); + } + } else { + items_to_buffer.push_back(result->to_json()); + } + + for (const auto& item : items_to_buffer) token_buffer.push_back(item); + + print_debug_buffer(token_buffer); + + // --- Check for Bans --- + std::string buffer_text = ""; + std::vector token_offsets; + for (const auto& item : token_buffer) { + token_offsets.push_back(buffer_text.length()); + buffer_text += get_content_str(item); + } + + std::string buffer_lower = to_lower_str(buffer_text); + size_t match_pos = std::string::npos; + std::string detected_phrase = ""; + + for (const auto& phrase : stop_phrases) { + size_t pos = buffer_lower.find(to_lower_str(phrase)); + if (pos != std::string::npos && (match_pos == std::string::npos || pos < match_pos)) { + match_pos = pos; + detected_phrase = phrase; + } + } + for (size_t i = 0; i < stop_regexes.size(); ++i) { + std::smatch match; + if (std::regex_search(buffer_text, match, stop_regexes[i])) { + size_t pos = match.position(0); + if (match_pos == std::string::npos || pos < match_pos) { + match_pos = pos; + detected_phrase = "REGEX:" + regex_patterns[i]; + } + } + } + + if (match_pos != std::string::npos) { + std::cout << "Debug: Stop phrase '" << detected_phrase << "' detected. Initiating ban logic." << std::endl; + + size_t split_index = 0; + bool found_split = false; + for (size_t i = 0; i < token_offsets.size(); ++i) { + if (token_offsets[i] + get_content_str(token_buffer[i]).length() > match_pos) { + split_index = i; + found_split = true; + break; + } + } + + if (found_split) { + std::string temp_prompt_suffix = ""; + std::deque good_tokens; + for (size_t i = 0; i < split_index; ++i) { + json& item = token_buffer[i]; + if (item.contains("__raw_token_id")) item.erase("__raw_token_id"); + temp_prompt_suffix += get_content_str(item); + good_tokens.push_back(item); + } + + json& guilty_item = token_buffer[split_index]; + int guilty_token_id = get_token_id(guilty_item); + if (guilty_token_id == -1) { + auto t = ctx_server.tokenize(get_content_str(guilty_item), false); + if (!t.empty()) guilty_token_id = t[0]; + } + + if (guilty_token_id != -1) { + if (ban_slot_index != (int)split_index) { + current_step_bans.clear(); + ban_slot_index = (int)split_index; + } + current_step_bans.insert(guilty_token_id); + std::cout << "Debug: Banning token ID " << guilty_token_id << " at slot " << split_index << ". Total bans: " << current_step_bans.size() << std::endl; + + request_cancel(current_task_id); + ctx_server.queue_results.remove_waiting_task_id(current_task_id); + + // Generate Fix Token + json fix_data = data; + fix_data["prompt"] = current_prompt_str + temp_prompt_suffix; + fix_data["n_predict"] = 1; + if (!fix_data.contains("logit_bias")) fix_data["logit_bias"] = json::array(); + + if (fix_data["logit_bias"].is_array()) { + for (int banned_id : current_step_bans) fix_data["logit_bias"].push_back(json::array({ banned_id, ban_bias })); + } else if (fix_data["logit_bias"].is_object()) { + for (int banned_id : current_step_bans) fix_data["logit_bias"][std::to_string(banned_id)] = ban_bias; + } + + std::cout << "Debug: Fix Data Logit Bias: " << fix_data["logit_bias"].dump() << std::endl; + + int id_fix = ctx_server.queue_tasks.get_new_id(); + *active_task_id = id_fix; + ctx_server.queue_results.add_waiting_task_id(id_fix); + + auto fix_inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, fix_data["prompt"], true, true); + post_single_task(id_fix, fix_data, fix_inputs[0]); + + std::unordered_set fix_ids = { id_fix }; + server_task_result_ptr fix_result = nullptr; + while (!fix_result) { + fix_result = ctx_server.queue_results.recv_with_timeout(fix_ids, 1); + if (!fix_result && !sink.is_writable()) { + request_cancel(id_fix); + ctx_server.queue_results.remove_waiting_task_id(id_fix); + return false; + } + } + ctx_server.queue_results.remove_waiting_task_id(id_fix); + + if (fix_result->is_error()) { + std::cout << "Debug: Fix task failed with error." << std::endl; + send_sse(sink, fix_result->to_json()); + return false; + } + + json fix_token_json; + if (oaicompat) { + auto parts = format_partial_response_oaicompat(*fix_result, completion_id); + if (!parts.empty()) fix_token_json = parts[0]; + } else { + fix_token_json = fix_result->data; + } + if (fix_result->data.contains("token")) fix_token_json["__raw_token_id"] = fix_result->data["token"]; + + // Resume Generation + json resume_data = data; + bool stop_after_fix = false; + if (original_n_predict > 0) { + int pending = good_tokens.size() + 1; + if (total_tokens_streamed + pending >= original_n_predict) stop_after_fix = true; + else resume_data["n_predict"] = original_n_predict - (total_tokens_streamed + pending); + } + + if (stop_after_fix) { + token_buffer = good_tokens; + token_buffer.push_back(fix_token_json); + while (!token_buffer.empty()) { + json& item = token_buffer.front(); + if (item.contains("__raw_token_id")) item.erase("__raw_token_id"); + if (!send_sse(sink, item)) { + request_cancel(*active_task_id); + ctx_server.queue_results.remove_waiting_task_id(*active_task_id); + return false; + } + total_tokens_streamed++; + token_buffer.pop_front(); + } + successful_completion = true; + goto cleanup; + } + + resume_data["prompt"] = current_prompt_str + temp_prompt_suffix + get_content_str(fix_token_json); + current_task_id = ctx_server.queue_tasks.get_new_id(); + *active_task_id = current_task_id; + ctx_server.queue_results.add_waiting_task_id(current_task_id); + + auto resume_inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, resume_data["prompt"], true, true); + post_single_task(current_task_id, resume_data, resume_inputs[0]); + + token_buffer = good_tokens; + token_buffer.push_back(fix_token_json); + } + } + } + + // Flush Logic + bool should_flush_all = result->is_stop() || result->is_error(); + if (token_buffer.size() >= BUFFER_SIZE || should_flush_all) { + while (!token_buffer.empty()) { + if (!should_flush_all && token_buffer.size() < BUFFER_SIZE) break; + + json& item_to_send = token_buffer.front(); + if (item_to_send.contains("__raw_token_id")) item_to_send.erase("__raw_token_id"); + + current_prompt_str += get_content_str(item_to_send); + + if (ban_slot_index != -1) { + if (0 == ban_slot_index) { + current_step_bans.clear(); + ban_slot_index = -1; + } else { + ban_slot_index--; + } + } + + if (!send_sse(sink, item_to_send)) { + request_cancel(current_task_id); + ctx_server.queue_results.remove_waiting_task_id(current_task_id); + return false; + } + + total_tokens_streamed++; + token_buffer.pop_front(); + + if (original_n_predict > 0 && total_tokens_streamed >= original_n_predict) { + request_cancel(current_task_id); + ctx_server.queue_results.remove_waiting_task_id(current_task_id); + successful_completion = true; + goto cleanup; + } + } + } + + if (result->is_error()) { + ctx_server.queue_results.remove_waiting_task_id(current_task_id); + return false; + } + if (result->is_stop()) { + successful_completion = true; + break; + } + } + + cleanup: + bool ok = true; + if (successful_completion && oaicompat != OAICOMPAT_TYPE_ANTHROPIC && oaicompat != OAICOMPAT_TYPE_NONE) { + static const std::string done_message = "data: [DONE]\n\n"; + LOG_VERBOSE("data stream", { {"to_send", done_message} }); + if (!sink.write(done_message.c_str(), done_message.size())) ok = false; + } + sink.done(); + ctx_server.queue_results.remove_waiting_task_id(*active_task_id); + return ok; + + } catch (const std::exception& e) { + std::cerr << "Exception in streaming handler: " << e.what() << std::endl; + send_sse(sink, json{ {"error", {{"message", e.what()}, {"type", "server_error"}, {"code", 500}}} }); + sink.done(); + if (active_task_id) { + request_cancel(*active_task_id); + ctx_server.queue_results.remove_waiting_task_id(*active_task_id); + } + return false; + } catch (...) { + std::cerr << "Unknown exception in streaming handler" << std::endl; + send_sse(sink, json{ {"error", {{"message", "Unknown error"}, {"type", "server_error"}, {"code", 500}}} }); + sink.done(); + if (active_task_id) { + request_cancel(*active_task_id); + ctx_server.queue_results.remove_waiting_task_id(*active_task_id); + } + return false; + } + }; + + auto on_complete = [active_task_id, &ctx_server, request_cancel](bool) { + int id_to_cancel = *active_task_id; + request_cancel(id_to_cancel); + ctx_server.queue_results.remove_waiting_task_id(id_to_cancel); + }; + + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); +} + +// --- END OF HELPER FUNCTIONS --- + + + + + int main(int argc, char ** argv) { #if SERVER_VERBOSE != 1 log_disable(); @@ -1039,6 +1554,7 @@ int main(int argc, char ** argv) { }; + // handle completion-like requests (completion, chat, infill) // we can optionally provide a custom format for partial results and final results const auto handle_completions_impl = [&ctx_server, ¶ms]( @@ -1050,40 +1566,16 @@ const auto handle_completions_impl = [&ctx_server, ¶ms]( oaicompat_type oaicompat) -> void { GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); - // ---------------------------------------------------------------- - // 1. Regex Validation (Common) - // ---------------------------------------------------------------- - auto validate_regex_list = [&](const std::string& field_name) -> std::string { - if (data.contains(field_name) && data[field_name].is_array()) { - for (const auto& val : data[field_name]) { - if (val.is_string()) { - std::string s = val.get(); - if (!s.empty()) { - try { - std::regex re(s); - } catch (const std::regex_error& e) { - return s; - } - } - } - } - } - return ""; - }; - - std::string invalid_re = validate_regex_list("banned_regex"); - if (invalid_re.empty()) invalid_re = validate_regex_list("banned_regex_case_insensitive"); - + // 1. Common Validation + std::string invalid_re = validate_banned_regex(data); if (!invalid_re.empty()) { res_err(res, format_error_response("Invalid regex: " + invalid_re, ERROR_TYPE_INVALID_REQUEST)); return; } const auto completion_id = gen_chatcmplid(); + std::string model_name = get_model_name(ctx_server.params_base.model); - // ---------------------------------------------------------------- - // Check if we need the complex "Banned String" logic - // ---------------------------------------------------------------- auto list_has_content = [&](const std::string& key) { if (data.contains(key) && data[key].is_array()) { for (const auto& item : data[key]) { @@ -1099,861 +1591,153 @@ const auto handle_completions_impl = [&ctx_server, ¶ms]( list_has_content("banned_regex") || list_has_content("banned_regex_case_insensitive"); - if (!has_banned_content) { - // ---------------------------------------------------------------- - // PATH A: Standard Logic (The "Old Way") - // ---------------------------------------------------------------- - - // need to store the reader as a pointer, so that it won't be destroyed when the handle returns - // use shared_ptr as it's shared between the chunked_content_provider() and on_complete() - const auto rd = std::make_shared(ctx_server); + // 2. Common Prompt Processing + std::vector inputs; + try { + const auto& prompt = data.at("prompt"); + if (oaicompat && ctx_server.mctx != nullptr) { + inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); + } + else { + inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true); + } + } + catch (const std::exception& e) { + res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return; + } - try { - std::vector tasks; + // 3. Branching Logic + if (has_banned_content) { + // PATH B: Complex Logic (Delegated to external function) + handle_completions_banned_impl(ctx_server, type, data, std::move(inputs[0]), is_connection_closed, res, oaicompat, completion_id); + return; + } - const auto& prompt = data.at("prompt"); + // PATH A: Standard Logic (Inline) + auto create_task = [type, oaicompat, completion_id, model_name](int id, int index, server_tokens tokens, json& task_data) { + server_task task(type); + task.id = id; + task.index = index; + task.tokens = std::move(tokens); + task.data = task_data; + task.id_slot = json_value(task_data, "id_slot", -1); + task.params.oaicompat = oaicompat; + task.params.oaicompat_cmpl_id = completion_id; + task.params.oaicompat_model = model_name; + return task; + }; - // process prompt - std::vector inputs; + auto send_sse = [oaicompat](httplib::DataSink& sink, const json& payload) -> bool { + if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) { + return server_sent_anthropic_event(sink, payload); + } else { + return server_sent_event(sink, payload); + } + }; - if (oaicompat && ctx_server.mctx != nullptr) { - inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); - } - else { - inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true); - } - tasks.reserve(inputs.size()); - for (size_t i = 0; i < inputs.size(); i++) { - server_task task = server_task(type); - - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - - task.tokens = std::move(inputs[i]); - task.data = data; - task.id_slot = json_value(data, "id_slot", -1); - - // OAI-compat - task.params.oaicompat = oaicompat; - task.params.oaicompat_cmpl_id = completion_id; - task.params.oaicompat_model = get_model_name(ctx_server.params_base.model); - tasks.push_back(std::move(task)); - } + bool stream = json_value(data, "stream", false); + const auto rd = std::make_shared(ctx_server); + std::vector tasks; + tasks.reserve(inputs.size()); + + for (size_t i = 0; i < inputs.size(); i++) { + tasks.push_back(create_task(ctx_server.queue_tasks.get_new_id(), i, std::move(inputs[i]), data)); + } + rd->post_tasks(std::move(tasks)); - rd->post_tasks(std::move(tasks)); + if (!stream) { + auto all_results = rd->wait_for_all(is_connection_closed); + if (all_results.is_terminated) { + llama_decode_stop(); + return; } - catch (const std::exception& e) { - res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + else if (all_results.error) { + res_err(res, all_results.error->to_json()); return; } - bool stream = json_value(data, "stream", false); - if (!stream) { - // non-stream, wait for the results - auto all_results = rd->wait_for_all(is_connection_closed); - if (all_results.is_terminated) { - llama_decode_stop(); // send a signal to stop decode process - return; // connection is closed - } - else if (all_results.error) { - res_err(res, all_results.error->to_json()); - return; - } - else { - json arr = json::array(); - for (auto& res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - if (oaicompat) { - arr.push_back(format_final_response_oaicompat(data, res->data, completion_id, false)); - } else { - arr.push_back(res->to_json()); - } - } - // if single request, return single object instead of array - res_ok(res, arr.size() == 1 ? arr[0] : arr); - } - } else { - // in streaming mode, the first error must be treated as non-stream response - // this is to match the OAI API behavior - // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 - server_task_result_ptr first_result = rd->next(is_connection_closed); - if (first_result == nullptr) { - llama_decode_stop(); // send a signal to stop decode process - return; // connection is closed - } - else if (first_result->is_error()) { - res_err(res, first_result->to_json()); - return; - } - else { - GGML_ASSERT( - dynamic_cast(first_result.get()) != nullptr - || dynamic_cast(first_result.get()) != nullptr - ); - } - - // Prepare first result JSON (handling OAI format if needed) - std::vector first_result_parts; - if (oaicompat) { - first_result_parts = format_partial_response_oaicompat(*first_result, completion_id); - } else { - first_result_parts.push_back(first_result->to_json()); - } - - const auto chunked_content_provider = [first_result_parts, rd, oaicompat, completion_id](size_t, httplib::DataSink& sink) mutable -> bool { - const auto sse = [oaicompat, &sink](const json& res) { - if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) { - return server_sent_anthropic_event(sink, res); - } - else { - return server_sent_event(sink, res); - } - }; - - // flush the first result parts - for (auto& part : first_result_parts) { - if (!part.empty()) { - if (!sse(part)) { - sink.done(); - return false; // sending failed, go to on_complete() - } - part.clear(); // mark as sent - } - } - - // receive subsequent results - auto result = rd->next([&sink] { return !sink.is_writable(); }); - if (result == nullptr) { - sink.done(); - return false; // connection is closed, go to on_complete() - } - - // send the results - bool ok = false; - if (result->is_error()) { - ok = sse(json{ { "error", result->to_json() } }); - sink.done(); - return false; // go to on_complete() + json arr = json::array(); + for (auto& res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + if (oaicompat) { + arr.push_back(format_final_response_oaicompat(data, res->data, completion_id, false)); + } else { + arr.push_back(res->to_json()); } - else { - GGML_ASSERT( - dynamic_cast(result.get()) != nullptr - || dynamic_cast(result.get()) != nullptr - ); - - if (oaicompat) { - std::vector parts = format_partial_response_oaicompat(*result, completion_id); - for (const auto& part : parts) { - ok = sse(part); - if (!ok) break; - } - } else { - ok = sse(result->to_json()); - } - } - - if (!ok) { - sink.done(); - return false; // sending failed, go to on_complete() - } - - // check if there is more data - if (!rd->has_next()) { - if (oaicompat != OAICOMPAT_TYPE_ANTHROPIC && oaicompat != OAICOMPAT_TYPE_NONE) { - static const std::string ev_done = "data: [DONE]\n\n"; - sink.write(ev_done.data(), ev_done.size()); - } - sink.done(); - return false; // no more data, go to on_complete() - } - - // has next data, continue - return true; - }; + } + res_ok(res, arr.size() == 1 ? arr[0] : arr); + } + } + else { + server_task_result_ptr first_result = rd->next(is_connection_closed); + if (first_result == nullptr) { + llama_decode_stop(); + return; + } + else if (first_result->is_error()) { + res_err(res, first_result->to_json()); + return; + } - auto on_complete = [rd](bool) { - rd->stop(); - }; - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + std::vector first_result_parts; + if (oaicompat) { + first_result_parts = format_partial_response_oaicompat(*first_result, completion_id); + } else { + first_result_parts.push_back(first_result->to_json()); } - } else { - // ---------------------------------------------------------------- - // PATH B: Banned Content Logic (Slow Path with Buffering & Rewind) - // ---------------------------------------------------------------- - auto buffer_and_check_string_ban_and_rewind_logic = [&]() { - // Process prompt / inputs (Duplicated here to keep Path A isolated) - std::vector inputs; - try { - const auto& prompt = data.at("prompt"); - if (oaicompat && ctx_server.mctx != nullptr) { - inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); - } - else { - inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true); + const auto chunked_content_provider = [first_result_parts, rd, oaicompat, send_sse, completion_id](size_t, httplib::DataSink& sink) mutable -> bool { + for (auto& part : first_result_parts) { + if (!part.empty()) { + if (!send_sse(sink, part)) { + sink.done(); + return false; + } + part.clear(); } } - catch (const std::exception& e) { - res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); - return; - } - // Helper to mimic request_cancel using the task queue directly - auto request_cancel = [&ctx_server](int id_target) { - server_task task(SERVER_TASK_TYPE_CANCEL); - task.id_target = id_target; - std::vector tasks; - tasks.push_back(std::move(task)); - ctx_server.queue_tasks.post(std::move(tasks), true); - }; - - // Helper to post a completion task with correct OAI params - auto post_task_with_params = [&ctx_server, oaicompat, completion_id](int id_task, json& task_data, server_tokens& tokens) { - server_task task(SERVER_TASK_TYPE_COMPLETION); - task.id = id_task; - task.index = 0; - task.tokens = std::move(tokens); - task.data = task_data; - task.id_slot = json_value(task_data, "id_slot", -1); - - // Critical: Set OAI params so worker generates correct output format - task.params.oaicompat = oaicompat; - task.params.oaicompat_cmpl_id = completion_id; - task.params.oaicompat_model = get_model_name(ctx_server.params_base.model); - - std::vector tasks; - tasks.push_back(std::move(task)); - ctx_server.queue_tasks.post(std::move(tasks)); - }; + auto result = rd->next([&sink] { return !sink.is_writable(); }); + if (result == nullptr) { + sink.done(); + return false; + } - const int id_task = ctx_server.queue_tasks.get_new_id(); - ctx_server.queue_results.add_waiting_task_id(id_task); - - // Use helper instead of request_completion - post_task_with_params(id_task, data, inputs[0]); - - bool stream = json_value(data, "stream", false); + if (result->is_error()) { + send_sse(sink, json{ { "error", result->to_json() } }); + sink.done(); + return false; + } - if (!stream) { - // Non-streaming: wait for result (using pointer to avoid slicing) - std::unordered_set ids = { id_task }; - server_task_result_ptr result = nullptr; - - // Simple blocking wait - while (!result) { - result = ctx_server.queue_results.recv_with_timeout(ids, 1); - if (!result && is_connection_closed()) { - request_cancel(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); - return; - } - } - - if (!result->is_error()) { - json result_json; - if (oaicompat) { - result_json = format_final_response_oaicompat(data, result->data, completion_id, false); - } else { - result_json = result->to_json(); + if (oaicompat) { + std::vector parts = format_partial_response_oaicompat(*result, completion_id); + for (const auto& part : parts) { + if (!send_sse(sink, part)) { + sink.done(); + return false; } - res.set_content(result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); } - else { - res_err(res, result->to_json()); + } else { + if (!send_sse(sink, result->to_json())) { + sink.done(); + return false; } - ctx_server.queue_results.remove_waiting_task_id(id_task); } - else { - // Shared state to track the currently running task ID across retries. - auto active_task_id = std::make_shared(id_task); - - // Capture 'data' by value to use as a template for retries - const auto chunked_content_provider = [id_task, active_task_id, &ctx_server, completion_id, oaicompat, send_done = params.send_done, data, request_cancel, post_task_with_params](size_t, httplib::DataSink& sink) mutable { - // Define sse here so it's visible to both try and catch blocks - const auto sse = [oaicompat, &sink](const json &res) { - if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) { - return server_sent_anthropic_event(sink, res); - } else { - return server_sent_event(sink, res); - } - }; - - try { - bool successful_completion = false; - - // 1. Parse Configuration from Request - // Banned Strings - std::vector stop_phrases; - if (data.contains("banned_strings") && data["banned_strings"].is_array()) { - for (const auto& val : data["banned_strings"]) { - if (val.is_string()) { - std::string s = val.get(); - if (!s.empty()) stop_phrases.push_back(s); - } - } - } - - // Sort banned strings by length (descending) - std::sort(stop_phrases.begin(), stop_phrases.end(), [](const std::string& a, const std::string& b) { - return a.length() > b.length(); - }); - - // Banned Regex (Case Sensitive & Insensitive) - std::vector regex_patterns; // For buffer size calculation - std::vector stop_regexes; // Compiled regexes - - auto add_regex_list = [&](const std::string& field_name, bool case_insensitive) { - if (data.contains(field_name) && data[field_name].is_array()) { - for (const auto& val : data[field_name]) { - if (val.is_string()) { - std::string s = val.get(); - if (!s.empty()) { - auto flags = std::regex_constants::ECMAScript; - if (case_insensitive) flags |= std::regex_constants::icase; - stop_regexes.emplace_back(s, flags); - regex_patterns.push_back(s); - } - } - } - } - }; - - // We assume validation passed in handle_completions_impl, so no try-catch needed here - add_regex_list("banned_regex", false); - add_regex_list("banned_regex_case_insensitive", true); - - // Logit Bias Penalty (Default: -10000.0) - float ban_bias = -10000.0f; - if (data.contains("banned_bias") && data["banned_bias"].is_number()) { - ban_bias = data["banned_bias"].get(); - } - - // Manual Buffer Size - size_t manual_buffer_size = 0; - if (data.contains("banbuffer_size") && data["banbuffer_size"].is_number_unsigned()) { - manual_buffer_size = data["banbuffer_size"].get(); - } - - // Token Limit Tracking - int original_n_predict = -1; - if (data.contains("n_predict") && data["n_predict"].is_number_integer()) { - original_n_predict = data["n_predict"].get(); - } - int total_tokens_streamed = 0; - - // ============================================================ - // FAST PATH: No banned strings AND No regex -> No buffering - // ============================================================ - if (stop_phrases.empty() && stop_regexes.empty()) { - while (true) { - std::unordered_set ids = { *active_task_id }; - server_task_result_ptr result = nullptr; - while (!result) { - result = ctx_server.queue_results.recv_with_timeout(ids, 1); - if (!result && !sink.is_writable()) { - request_cancel(*active_task_id); - ctx_server.queue_results.remove_waiting_task_id(*active_task_id); - return false; - } - } - - if (!result->is_error()) { - // Use format_partial_response_oaicompat to get the correct chunks - std::vector parts; - if (oaicompat) { - parts = format_partial_response_oaicompat(*result, completion_id); - } else { - parts.push_back(result->data); - } - - for (const auto& item : parts) { - if (!sse(item)) { - request_cancel(*active_task_id); - ctx_server.queue_results.remove_waiting_task_id(*active_task_id); - return false; - } - } - - if (result->is_stop()) { - successful_completion = true; - break; - } - } else { - sse(result->to_json()); - ctx_server.queue_results.remove_waiting_task_id(*active_task_id); - return false; - } - } - } - // ============================================================ - // SLOW PATH: Buffering and Banning Logic - // ============================================================ - else { - // Calculate Buffer Size - size_t BUFFER_SIZE; - if (manual_buffer_size > 0) { - BUFFER_SIZE = manual_buffer_size; - } else { - size_t max_len = 0; - // Check strings - if (!stop_phrases.empty()) { - max_len = stop_phrases[0].length(); // First is longest due to sort - } - // Check regex patterns - for (const auto& pat : regex_patterns) { - if (pat.length() > max_len) max_len = pat.length(); - } - - // Default: Longest string/regex + 1 - BUFFER_SIZE = std::max((size_t)1, max_len + 1); - } - - // Initialize Buffer & State - std::deque token_buffer; - - int current_task_id = id_task; - - // Track bans specifically for the current "next token" to be generated. - std::set current_step_bans; - int ban_slot_index = -1; - - // Track the text that has been confirmed/sent to the client. - std::string current_prompt_str = ""; - if (data.contains("prompt") && data["prompt"].is_string()) { - current_prompt_str = data["prompt"].get(); - } - - // Helper to extract text content - auto get_content_str = [](const json& j) -> std::string { - if (j.contains("choices") && j["choices"].is_array() && !j["choices"].empty()) { - const auto& choice = j["choices"][0]; - if (choice.contains("delta") && choice["delta"].contains("content")) { - auto val = choice["delta"]["content"]; - if (val.is_string()) return val.get(); - } - } - if (j.contains("content")) { - auto val = j["content"]; - if (val.is_string()) return val.get(); - } - return ""; - }; - - // Helper to extract Token ID - auto get_token_id = [](const json& j) -> int { - if (j.contains("__raw_token_id")) return j["__raw_token_id"].get(); - if (j.contains("token")) return j["token"].get(); - if (j.contains("id")) return j["id"].get(); - return -1; - }; - - // Helper for case-insensitive search - auto to_lower_str = [](std::string s) { - std::transform(s.begin(), s.end(), s.begin(), - [](unsigned char c){ return std::tolower(c); }); - return s; - }; - - // Helper to print buffer - auto print_debug_buffer = [&](const std::deque& buf) { - std::cout << "Debug TokenBuffer (Size " << BUFFER_SIZE << "): ["; - size_t print_len = std::max(buf.size(), BUFFER_SIZE); - for (size_t i = 0; i < print_len; ++i) { - if (i < buf.size()) { - std::string content = get_content_str(buf[i]); - std::string escaped; - for (char c : content) { - if (c == '\n') escaped += "\\n"; - else if (c == '"') escaped += "\\\""; - else escaped += c; - } - std::cout << "\"" << escaped << "\""; - } else { - std::cout << "\"\""; - } - if (i < print_len - 1) std::cout << ", "; - } - std::cout << "]" << std::endl; - }; - - while (true) { - // Ensure shared state matches current local state - *active_task_id = current_task_id; - - // 0. Check connection status explicitly - if (!sink.is_writable()) { - request_cancel(current_task_id); - ctx_server.queue_results.remove_waiting_task_id(current_task_id); - return false; - } - - // Receive from the CURRENT task ID using pointer to avoid slicing - std::unordered_set ids = { current_task_id }; - server_task_result_ptr result = nullptr; - while (!result) { - result = ctx_server.queue_results.recv_with_timeout(ids, 1); - if (!result && !sink.is_writable()) { - request_cancel(current_task_id); - ctx_server.queue_results.remove_waiting_task_id(current_task_id); - return false; - } - } - - std::vector items_to_buffer; - - if (!result->is_error()) { - // Use format_partial_response_oaicompat to get the correct chunks - std::vector parts; - if (oaicompat) { - parts = format_partial_response_oaicompat(*result, completion_id); - } else { - parts.push_back(result->data); - } - - json raw_data = result->data; // Access raw data for token ID - - for (const auto& r : parts) { - json item = r; - // Attach raw token ID for banning logic - if (raw_data.contains("token")) item["__raw_token_id"] = raw_data["token"]; - items_to_buffer.push_back(item); - } - } else { - items_to_buffer.push_back(result->to_json()); - } - - // 2. Process items into buffer - for (const auto& item : items_to_buffer) { - token_buffer.push_back(item); - } - - print_debug_buffer(token_buffer); - - // 3. Check for Stop Phrases (Strings & Regex) - std::string buffer_text = ""; - std::vector token_offsets; - - for (const auto& item : token_buffer) { - token_offsets.push_back(buffer_text.length()); - buffer_text += get_content_str(item); - } - - std::string buffer_lower = to_lower_str(buffer_text); - - size_t match_pos = std::string::npos; - std::string detected_phrase = ""; - - // A. Check Strings (Case Insensitive) - for (const auto& phrase : stop_phrases) { - std::string target_lower = to_lower_str(phrase); - size_t pos = buffer_lower.find(target_lower); - if (pos != std::string::npos) { - if (match_pos == std::string::npos || pos < match_pos) { - match_pos = pos; - detected_phrase = phrase; - } - } - } - - // B. Check Regex - for (size_t i = 0; i < stop_regexes.size(); ++i) { - std::smatch match; - // We search the raw buffer_text - if (std::regex_search(buffer_text, match, stop_regexes[i])) { - size_t pos = match.position(0); - if (match_pos == std::string::npos || pos < match_pos) { - match_pos = pos; - detected_phrase = "REGEX:" + regex_patterns[i]; - } - } - } - - if (match_pos != std::string::npos) { - std::cout << "Debug: Stop phrase '" << detected_phrase << "' detected. Initiating ban logic." << std::endl; - - // Find the guilty token - size_t split_index = 0; - bool found_split = false; - for (size_t i = 0; i < token_offsets.size(); ++i) { - size_t token_start = token_offsets[i]; - std::string content = get_content_str(token_buffer[i]); - size_t token_end = token_start + content.length(); - - if (token_end > match_pos) { - split_index = i; - found_split = true; - break; - } - } - - if (found_split) { - // 1. Construct prompt from good tokens (DO NOT FLUSH) - std::string temp_prompt_suffix = ""; - std::deque good_tokens; - - for (size_t i = 0; i < split_index; ++i) { - json& item = token_buffer[i]; - if (item.contains("__raw_token_id")) item.erase("__raw_token_id"); - temp_prompt_suffix += get_content_str(item); - good_tokens.push_back(item); - } - - // 2. Identify Guilty Token & Add to Bans - json& guilty_item = token_buffer[split_index]; - int guilty_token_id = get_token_id(guilty_item); - - if (guilty_token_id == -1) { - std::string content = get_content_str(guilty_item); - auto tokens = ctx_server.tokenize(content, false); - if (!tokens.empty()) guilty_token_id = tokens[0]; - } - - if (guilty_token_id != -1) { - // Check if we are banning a different slot than before - if (ban_slot_index != (int)split_index) { - current_step_bans.clear(); - ban_slot_index = (int)split_index; - } - - current_step_bans.insert(guilty_token_id); - std::cout << "Debug: Banning token ID " << guilty_token_id << " at slot " << split_index << ". Total bans: " << current_step_bans.size() << std::endl; - - // 3. Cancel current task - request_cancel(current_task_id); - ctx_server.queue_results.remove_waiting_task_id(current_task_id); - - // 4. FIX STEP: Generate 1 token with ALL current bans - json fix_data = data; - fix_data["prompt"] = current_prompt_str + temp_prompt_suffix; - fix_data["n_predict"] = 1; - - // Robust logit_bias handling - if (!fix_data.contains("logit_bias")) { - fix_data["logit_bias"] = json::array(); - } - - if (fix_data["logit_bias"].is_array()) { - for (int banned_id : current_step_bans) { - fix_data["logit_bias"].push_back(json::array({banned_id, ban_bias})); - } - } else if (fix_data["logit_bias"].is_object()) { - for (int banned_id : current_step_bans) { - fix_data["logit_bias"][std::to_string(banned_id)] = ban_bias; - } - } - - std::cout << "Debug: Fix Data Logit Bias: " << fix_data["logit_bias"].dump() << std::endl; - - int id_fix = ctx_server.queue_tasks.get_new_id(); - *active_task_id = id_fix; // Update shared state for fix task - ctx_server.queue_results.add_waiting_task_id(id_fix); - - std::vector fix_inputs = tokenize_input_prompts( - llama_get_vocab(ctx_server.ctx), ctx_server.mctx, fix_data["prompt"], true, true - ); - - // Use helper - post_task_with_params(id_fix, fix_data, fix_inputs[0]); - - // Wait for the fix token - std::unordered_set fix_ids = { id_fix }; - server_task_result_ptr fix_result = nullptr; - while (!fix_result) { - fix_result = ctx_server.queue_results.recv_with_timeout(fix_ids, 1); - if (!fix_result && !sink.is_writable()) { - request_cancel(id_fix); - ctx_server.queue_results.remove_waiting_task_id(id_fix); - return false; - } - } - ctx_server.queue_results.remove_waiting_task_id(id_fix); - - // Check for error in fix result - if (fix_result->is_error()) { - std::cout << "Debug: Fix task failed with error." << std::endl; - sse(fix_result->to_json()); - return false; - } - - // Process fix token - json fix_token_json; - json raw_fix = fix_result->data; - - // Use format_partial_response_oaicompat for fix token too - if (oaicompat) { - std::vector parts = format_partial_response_oaicompat(*fix_result, completion_id); - if (!parts.empty()) fix_token_json = parts[0]; - } else { - fix_token_json = fix_result->data; - } - - if (raw_fix.contains("token")) fix_token_json["__raw_token_id"] = raw_fix["token"]; - - std::string fix_content = get_content_str(fix_token_json); - - // 5. RESUME STEP: Continue generation normally - json resume_data = data; - bool stop_after_fix = false; - - if (original_n_predict > 0) { - int pending = good_tokens.size() + 1; - if (total_tokens_streamed + pending >= original_n_predict) { - stop_after_fix = true; - } else { - resume_data["n_predict"] = original_n_predict - (total_tokens_streamed + pending); - } - } - - if (stop_after_fix) { - token_buffer = good_tokens; - token_buffer.push_back(fix_token_json); - - while (!token_buffer.empty()) { - json& item = token_buffer.front(); - if (item.contains("__raw_token_id")) item.erase("__raw_token_id"); - if (!sse(item)) { - request_cancel(*active_task_id); - ctx_server.queue_results.remove_waiting_task_id(*active_task_id); - return false; - } - total_tokens_streamed++; - token_buffer.pop_front(); - } - successful_completion = true; - goto cleanup; - } - - resume_data["prompt"] = current_prompt_str + temp_prompt_suffix + fix_content; - - current_task_id = ctx_server.queue_tasks.get_new_id(); - *active_task_id = current_task_id; // Update shared state for resume task - ctx_server.queue_results.add_waiting_task_id(current_task_id); - - std::vector resume_inputs = tokenize_input_prompts( - llama_get_vocab(ctx_server.ctx), ctx_server.mctx, resume_data["prompt"], true, true - ); - - // Use helper - post_task_with_params(current_task_id, resume_data, resume_inputs[0]); - - // 6. Update Buffer: Good Tokens + Fix Token - token_buffer = good_tokens; - token_buffer.push_back(fix_token_json); - - // REMOVED continue; to allow flush logic to run - } - } - } - - // 4. Standard Flush Logic - bool should_flush_all = result->is_stop() || result->is_error(); - - if (token_buffer.size() >= BUFFER_SIZE || should_flush_all) { - while (!token_buffer.empty()) { - if (!should_flush_all && token_buffer.size() < BUFFER_SIZE) { - break; - } - - json& item_to_send = token_buffer.front(); - if (item_to_send.contains("__raw_token_id")) item_to_send.erase("__raw_token_id"); - - current_prompt_str += get_content_str(item_to_send); - - // SMART BAN CLEARING LOGIC - if (ban_slot_index != -1) { - if (0 == ban_slot_index) { - // We are flushing the slot that had bans. - // This means it's now accepted (or we are forced to flush). - current_step_bans.clear(); - ban_slot_index = -1; - } else { - // We are flushing a preceding token. - // The banned slot shifts left. - ban_slot_index--; - } - } - - if (!sse(item_to_send)) { - request_cancel(current_task_id); - ctx_server.queue_results.remove_waiting_task_id(current_task_id); - return false; - } - - total_tokens_streamed++; - token_buffer.pop_front(); - - if (original_n_predict > 0 && total_tokens_streamed >= original_n_predict) { - request_cancel(current_task_id); - ctx_server.queue_results.remove_waiting_task_id(current_task_id); - successful_completion = true; - goto cleanup; - } - } - } - - if (result->is_error()) { - ctx_server.queue_results.remove_waiting_task_id(current_task_id); - return false; - } - - if (result->is_stop()) { - successful_completion = true; - break; - } - } - } - - cleanup: - bool ok = true; - if (successful_completion && oaicompat != OAICOMPAT_TYPE_ANTHROPIC && oaicompat != OAICOMPAT_TYPE_NONE) { - static const std::string done_message = "data: [DONE]\n\n"; - LOG_VERBOSE("data stream", { {"to_send", done_message} }); - if (!sink.write(done_message.c_str(), done_message.size())) { - ok = false; - } - } - sink.done(); - - // Cleanup the active task ID (which might be different from id_task in slow path) - ctx_server.queue_results.remove_waiting_task_id(*active_task_id); - - return ok; - } catch (const std::exception& e) { - // Catch any exceptions to prevent crashing the server - std::cerr << "Exception in streaming handler: " << e.what() << std::endl; - sse(json{{"error", {{"message", e.what()}, {"type", "server_error"}, {"code", 500}}}}); - sink.done(); - if (active_task_id) { - request_cancel(*active_task_id); - ctx_server.queue_results.remove_waiting_task_id(*active_task_id); - } - return false; - } catch (...) { - std::cerr << "Unknown exception in streaming handler" << std::endl; - sse(json{{"error", {{"message", "Unknown error"}, {"type", "server_error"}, {"code", 500}}}}); - sink.done(); - if (active_task_id) { - request_cancel(*active_task_id); - ctx_server.queue_results.remove_waiting_task_id(*active_task_id); - } - return false; - } - }; - - auto on_complete = [active_task_id, &ctx_server, request_cancel](bool) { - // Cancel the currently active task ID - int id_to_cancel = *active_task_id; - request_cancel(id_to_cancel); - ctx_server.queue_results.remove_waiting_task_id(id_to_cancel); - }; - - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + if (!rd->has_next()) { + if (oaicompat != OAICOMPAT_TYPE_ANTHROPIC && oaicompat != OAICOMPAT_TYPE_NONE) { + static const std::string ev_done = "data: [DONE]\n\n"; + sink.write(ev_done.data(), ev_done.size()); + } + sink.done(); + return false; } + return true; }; - - // Execute the complex logic - buffer_and_check_string_ban_and_rewind_logic(); + + auto on_complete = [rd](bool) { rd->stop(); }; + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); } }; From 9c5c506541c3b1a0aa13126af35462e2612b1293 Mon Sep 17 00:00:00 2001 From: SneedwareInc <254158255+SneedwareInc@users.noreply.github.com> Date: Fri, 6 Feb 2026 05:35:22 +0100 Subject: [PATCH 6/8] Update1 --- examples/server/server-context.cpp | 155 ++++++- examples/server/server-context.h | 6 +- examples/server/server.cpp | 720 ++++------------------------- 3 files changed, 229 insertions(+), 652 deletions(-) diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 8e71b2db0..5b97127ad 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -11,6 +11,7 @@ #include "mtmd.h" #include "mtmd-helper.h" +#include server_context::~server_context() { if (ctx) { @@ -1123,23 +1124,25 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) { // ban string + int32_t banbuffer_size = json_value(data, "banbuffer_size", 0); + slot.n_buffer = 0; // Ensure buffer calculation starts fresh for this slot + const auto& banned_strings = data.find("banned_strings"); if (banned_strings != data.end() && banned_strings->is_array()) { - slot.ban_phrases.clear(); + slot.ban_phrases.clear(); for (const auto& val : data["banned_strings"]) { if (val.is_string()) { std::string s = val.get(); if (!s.empty()) { s = string_lower(s); - auto ban_tokens = common_tokenize(llama_get_model(ctx), s, false, true); - if (ban_tokens.size() > slot.n_buffer) { - slot.n_buffer = ban_tokens.size(); + // Use string length instead of token count + if (s.length() > slot.n_buffer) { + slot.n_buffer = s.length(); } slot.ban_phrases.push_back(s); } } } - slot.n_buffer = slot.n_buffer + 3; // extra buffer in case std::sort(slot.ban_phrases.begin(), slot.ban_phrases.end(), [](const std::string& a, const std::string& b) { return a.length() > b.length(); }); @@ -1149,24 +1152,77 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) std::sort(params_base.ban_phrases.begin(), params_base.ban_phrases.end(), [](const std::string & a, const std::string & b) { return a.length() > b.length(); }); + for (auto & val : params_base.ban_phrases) { if (!val.empty()) { val = string_lower(val); - auto ban_tokens = common_tokenize(llama_get_model(ctx), val, false, true); - if (ban_tokens.size() > slot.n_buffer) { - slot.n_buffer = ban_tokens.size(); + // Use string length instead of token count + if (val.length() > slot.n_buffer) { + slot.n_buffer = val.length(); } slot.ban_phrases.push_back(val); } } - slot.n_buffer = slot.n_buffer + 3; // extra buffer in case - params_base.n_buffer = slot.n_buffer; + params_base.n_buffer = slot.n_buffer + 1; // buffer is longest string + 1 } else { slot.ban_phrases = params_base.ban_phrases; slot.n_buffer = params_base.n_buffer; } } - slot.logit_bias = slot.sparams.logit_bias; // keep a copy to restore + + // ban regex + slot.ban_regex.clear(); + const auto& banned_regex = data.find("banned_regex"); + if (banned_regex != data.end() && banned_regex->is_array()) { + for (const auto& val : data["banned_regex"]) { + if (val.is_string()) { + std::string s = val.get(); + if (!s.empty()) { + try { + std::regex re(s); + slot.ban_regex.push_back(s); + if (s.length() > slot.n_buffer) { + slot.n_buffer = s.length(); + } + } catch (const std::regex_error& e) { + send_error(task, "Invalid regex in banned_regex: " + s, ERROR_TYPE_INVALID_REQUEST); + return false; + } + } + } + } + } + + // ban regex case insensitive + slot.ban_regex_ci.clear(); + const auto& banned_regex_ci = data.find("banned_regex_case_insensitive"); + if (banned_regex_ci != data.end() && banned_regex_ci->is_array()) { + for (const auto& val : data["banned_regex_case_insensitive"]) { + if (val.is_string()) { + std::string s = val.get(); + if (!s.empty()) { + try { + std::regex re(s, std::regex_constants::icase); + slot.ban_regex_ci.push_back(s); + if (s.length() > slot.n_buffer) { + slot.n_buffer = s.length(); + } + } catch (const std::regex_error& e) { + send_error(task, "Invalid regex in banned_regex_case_insensitive: " + s, ERROR_TYPE_INVALID_REQUEST); + return false; + } + } + } + } + } + + if (banbuffer_size > 0) { + slot.n_buffer = banbuffer_size; + } else { + slot.n_buffer = slot.n_buffer + 1; // buffer is longest string/regex + 1 + } + + slot.logit_bias = slot.sparams.logit_bias; // keep a copy to restore slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias); slot.banned_n = json_value(data, "banned_n", params_base.banned_n); } @@ -2958,14 +3014,20 @@ bool server_context::accept_special_token(const server_slot& slot, const llama_ void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) { int count = 0; + bool released = false; for (auto& it : results) { bool has_next = process_token(it, slot); count++; if (!has_next) { + // If stopped by limit, continue processing the buffer to ensure all generated tokens are sent + if (slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) { + continue; + } slot.release(); slot.print_timings(); send_final_response(slot); metrics.on_prediction(slot); + released = true; break; } if (n > 0 && count >= n) { @@ -2973,6 +3035,14 @@ void server_context::send_token_results(completion_token_outputs& results, serve } } + // If we finished the loop due to limit (and didn't release yet), release now + if (!released && slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) { + slot.release(); + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + } + if (count > 0) { slot.sampled = results[results.size()-1].tok; results.erase(results.begin(), results.begin() + count); @@ -2983,7 +3053,7 @@ void server_context::send_token_results(completion_token_outputs& results, serve inline int32_t check_ban_phrase(const server_slot& slot) { bool found = false; size_t n = slot.token_buffer.size(); - size_t start; + size_t best_start = std::string::npos; int32_t n_rewind = 0; std::string string_buffer; llama_tokens tokens; @@ -2991,19 +3061,56 @@ inline int32_t check_ban_phrase(const server_slot& slot) { string_buffer = string_buffer + it.text_to_send; tokens.push_back(it.tok); } - string_buffer = string_lower(string_buffer); + std::string string_buffer_lower = string_lower(string_buffer); + + // Check strings for (auto it : slot.ban_phrases) { - start = string_buffer.find(it); - // has been sorted from longest to shortest + size_t start = string_buffer_lower.find(it); if (start != std::string::npos) { - found = true; - break; + if (start < best_start) { + best_start = start; + found = true; + } + } + } + + // Check regex + for (const auto& pattern : slot.ban_regex) { + try { + std::regex re(pattern); + std::smatch match; + if (std::regex_search(string_buffer, match, re)) { + if (match.position() < best_start) { + best_start = match.position(); + found = true; + } + } + } catch (const std::regex_error& e) { + // Should be caught during validation, but safe fallback + continue; + } + } + + // Check regex case insensitive + for (const auto& pattern : slot.ban_regex_ci) { + try { + std::regex re(pattern, std::regex_constants::icase); + std::smatch match; + if (std::regex_search(string_buffer, match, re)) { + if (match.position() < best_start) { + best_start = match.position(); + found = true; + } + } + } catch (const std::regex_error& e) { + continue; } } + if (found) { std::vector unused; - LLAMA_LOG_DEBUG("Banned string dectected: %s\n ", string_buffer.substr(start).c_str()); - n = find_n_tokens_from_string(slot.ctx, tokens, start, 0, unused); + LLAMA_LOG_DEBUG("Banned string/regex dectected: %s\n ", string_buffer.substr(best_start).c_str()); + n = find_n_tokens_from_string(slot.ctx, tokens, best_start, 0, unused); n_rewind = (int32_t) slot.token_buffer.size() - (int32_t) n; } return n_rewind; @@ -3035,6 +3142,10 @@ inline void rewind_context(server_slot& slot, int32_t n_rewind) { size_t n_keep = slot.cache_tokens.size() - n_rewind; slot.sampled = slot.cache_tokens[n_keep]; slot.cache_tokens.keep_first(n_keep); + + // Adjust decoded count so user doesn't lose budget on banned tokens + slot.n_decoded -= n_rewind; + if (slot.n_decoded < 0) slot.n_decoded = 0; } void server_context::buffer_and_check_string_ban(server_slot & slot, completion_token_output & result) { @@ -3047,11 +3158,11 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_ if (!slot.rewind_status) { slot.ctx_sampling->params.logit_bias = slot.logit_bias; // restore logit bias } - if (slot.ban_phrases.size() > 0) { + if (slot.ban_phrases.size() > 0 || slot.ban_regex.size() > 0 || slot.ban_regex_ci.size() > 0) { n_rewind = check_ban_phrase(slot); } // if found string in the ban - if (n_rewind > 0 && (slot.rewind_count <20 || slot.rewind_count <= 2 * slot.ban_phrases.size())) { + if (n_rewind > 0) { rewind_context(slot, n_rewind); slot.rewind_status = true; } @@ -3257,4 +3368,4 @@ json server_context::model_meta() const { {"n_params", llama_model_n_params(model)}, {"size", llama_model_size(model)}, }; -} +} \ No newline at end of file diff --git a/examples/server/server-context.h b/examples/server/server-context.h index a5676e045..088e3464b 100644 --- a/examples/server/server-context.h +++ b/examples/server/server-context.h @@ -88,7 +88,9 @@ struct server_slot { int32_t rewind_count = 0; bool rewind_status = false; std::unordered_map logit_bias; - std::vectorban_phrases; + std::vector ban_phrases; + std::vector ban_regex; + std::vector ban_regex_ci; completion_token_outputs token_buffer; float ban_phrases_bias = 0; int32_t banned_n = 1; @@ -336,4 +338,4 @@ struct server_context { // Re-aggregates all active vectors and updates the model state bool apply_control_vectors_internal(); -}; +}; \ No newline at end of file diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 5290dcfac..ee8edd7b6 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -431,522 +431,6 @@ inline void signal_handler(int signal) { shutdown_handler(signal); } - - - - -// --- START OF HELPER FUNCTIONS --- - -static std::string validate_banned_regex(const json& data) { - auto validate_list = [&](const std::string& field_name) -> std::string { - if (data.contains(field_name) && data[field_name].is_array()) { - for (const auto& val : data[field_name]) { - if (val.is_string()) { - std::string s = val.get(); - if (!s.empty()) { - try { - std::regex re(s); - } catch (const std::regex_error&) { - return s; - } - } - } - } - } - return ""; - }; - - std::string invalid = validate_list("banned_regex"); - if (invalid.empty()) { - invalid = validate_list("banned_regex_case_insensitive"); - } - return invalid; -} - -static void handle_completions_banned_impl( - server_context& ctx_server, - server_task_type type, - json data, // Passed by value for internal modification - server_tokens tokens, - const std::function& is_connection_closed, - httplib::Response& res, - oaicompat_type oaicompat, - std::string completion_id -) { - // Local version of res_err - auto res_err = [](httplib::Response& res, json error_data) { - json final_response{ {"error", error_data} }; - res.set_content(final_response.dump(), MIMETYPE_JSON); - res.status = json_value(error_data, "code", 500); - }; - - std::string model_name = get_model_name(ctx_server.params_base.model); - - auto create_task = [type, oaicompat, completion_id, model_name](int id, int index, server_tokens t, json& task_data) { - server_task task(type); - task.id = id; - task.index = index; - task.tokens = std::move(t); - task.data = task_data; - task.id_slot = json_value(task_data, "id_slot", -1); - task.params.oaicompat = oaicompat; - task.params.oaicompat_cmpl_id = completion_id; - task.params.oaicompat_model = model_name; - return task; - }; - - auto request_cancel = [&ctx_server](int id_target) { - server_task task(SERVER_TASK_TYPE_CANCEL); - task.id_target = id_target; - std::vector tasks; - tasks.push_back(std::move(task)); - ctx_server.queue_tasks.post(std::move(tasks), true); - }; - - auto send_sse = [oaicompat](httplib::DataSink& sink, const json& payload) -> bool { - if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) { - return server_sent_anthropic_event(sink, payload); - } else { - return server_sent_event(sink, payload); - } - }; - - auto post_single_task = [&ctx_server, create_task](int id_task, json& task_data, server_tokens& t) { - std::vector tasks; - tasks.push_back(create_task(id_task, 0, std::move(t), task_data)); - ctx_server.queue_tasks.post(std::move(tasks)); - }; - - // Initial Task Setup - const int id_task = ctx_server.queue_tasks.get_new_id(); - ctx_server.queue_results.add_waiting_task_id(id_task); - post_single_task(id_task, data, tokens); - - bool stream = json_value(data, "stream", false); - - // Non-Streaming Logic - if (!stream) { - std::unordered_set ids = { id_task }; - server_task_result_ptr result = nullptr; - - while (!result) { - result = ctx_server.queue_results.recv_with_timeout(ids, 1); - if (!result && is_connection_closed()) { - request_cancel(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); - return; - } - } - - if (!result->is_error()) { - json result_json; - if (oaicompat) { - result_json = format_final_response_oaicompat(data, result->data, completion_id, false); - } else { - result_json = result->to_json(); - } - res.set_content(result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); - } else { - res_err(res, result->to_json()); - } - ctx_server.queue_results.remove_waiting_task_id(id_task); - return; - } - - // Streaming Logic (Buffering & Rewind) - auto active_task_id = std::make_shared(id_task); - - const auto chunked_content_provider = [id_task, active_task_id, &ctx_server, completion_id, oaicompat, data, request_cancel, post_single_task, send_sse](size_t, httplib::DataSink& sink) mutable { - try { - bool successful_completion = false; - - // --- Parse Banned Config --- - std::vector stop_phrases; - if (data.contains("banned_strings") && data["banned_strings"].is_array()) { - for (const auto& val : data["banned_strings"]) { - if (val.is_string() && !val.get().empty()) { - stop_phrases.push_back(val.get()); - } - } - } - std::sort(stop_phrases.begin(), stop_phrases.end(), [](const std::string& a, const std::string& b) { - return a.length() > b.length(); - }); - - std::vector regex_patterns; - std::vector stop_regexes; - auto add_regex = [&](const std::string& field, bool icase) { - if (data.contains(field) && data[field].is_array()) { - for (const auto& val : data[field]) { - if (val.is_string() && !val.get().empty()) { - auto flags = std::regex_constants::ECMAScript; - if (icase) flags |= std::regex_constants::icase; - stop_regexes.emplace_back(val.get(), flags); - regex_patterns.push_back(val.get()); - } - } - } - }; - add_regex("banned_regex", false); - add_regex("banned_regex_case_insensitive", true); - - float ban_bias = json_value(data, "banned_bias", -10000.0f); - size_t manual_buffer_size = json_value(data, "banbuffer_size", (size_t)0); - int original_n_predict = json_value(data, "n_predict", -1); - int total_tokens_streamed = 0; - - // Calculate Buffer Size - size_t BUFFER_SIZE = manual_buffer_size; - if (BUFFER_SIZE == 0) { - size_t max_len = stop_phrases.empty() ? 0 : stop_phrases[0].length(); - for (const auto& pat : regex_patterns) max_len = std::max(max_len, pat.length()); - BUFFER_SIZE = std::max((size_t)1, max_len + 1); - } - - std::deque token_buffer; - int current_task_id = id_task; - std::set current_step_bans; - int ban_slot_index = -1; - std::string current_prompt_str = json_value(data, "prompt", std::string("")); - - auto get_content_str = [](const json& j) -> std::string { - if (j.contains("choices") && !j["choices"].empty()) { - auto& d = j["choices"][0]["delta"]; - if (d.contains("content") && d["content"].is_string()) return d["content"]; - } - if (j.contains("content") && j["content"].is_string()) return j["content"]; - return ""; - }; - - // FIXED: Added type checks to prevent string-to-int conversion errors - auto get_token_id = [](const json& j) -> int { - if (j.contains("__raw_token_id") && j["__raw_token_id"].is_number()) return j["__raw_token_id"]; - if (j.contains("token") && j["token"].is_number()) return j["token"]; - if (j.contains("id") && j["id"].is_number()) return j["id"]; - return -1; - }; - - auto to_lower_str = [](std::string s) { - std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); }); - return s; - }; - - auto print_debug_buffer = [&](const std::deque& buf) { - std::cout << "Debug TokenBuffer (Size " << BUFFER_SIZE << "): ["; - size_t print_len = std::max(buf.size(), BUFFER_SIZE); - for (size_t i = 0; i < print_len; ++i) { - if (i < buf.size()) { - std::string content = get_content_str(buf[i]); - std::string escaped; - for (char c : content) { - if (c == '\n') escaped += "\\n"; - else if (c == '"') escaped += "\\\""; - else escaped += c; - } - std::cout << "\"" << escaped << "\""; - } else { - std::cout << "\"\""; - } - if (i < print_len - 1) std::cout << ", "; - } - std::cout << "]" << std::endl; - }; - - while (true) { - *active_task_id = current_task_id; - - if (!sink.is_writable()) { - request_cancel(current_task_id); - ctx_server.queue_results.remove_waiting_task_id(current_task_id); - return false; - } - - std::unordered_set ids = { current_task_id }; - server_task_result_ptr result = nullptr; - while (!result) { - result = ctx_server.queue_results.recv_with_timeout(ids, 1); - if (!result && !sink.is_writable()) { - request_cancel(current_task_id); - ctx_server.queue_results.remove_waiting_task_id(current_task_id); - return false; - } - } - - std::vector items_to_buffer; - if (!result->is_error()) { - std::vector parts; - if (oaicompat) parts = format_partial_response_oaicompat(*result, completion_id); - else parts.push_back(result->data); - - json raw_data = result->data; - for (const auto& r : parts) { - json item = r; - if (raw_data.contains("token")) item["__raw_token_id"] = raw_data["token"]; - items_to_buffer.push_back(item); - } - } else { - items_to_buffer.push_back(result->to_json()); - } - - for (const auto& item : items_to_buffer) token_buffer.push_back(item); - - print_debug_buffer(token_buffer); - - // --- Check for Bans --- - std::string buffer_text = ""; - std::vector token_offsets; - for (const auto& item : token_buffer) { - token_offsets.push_back(buffer_text.length()); - buffer_text += get_content_str(item); - } - - std::string buffer_lower = to_lower_str(buffer_text); - size_t match_pos = std::string::npos; - std::string detected_phrase = ""; - - for (const auto& phrase : stop_phrases) { - size_t pos = buffer_lower.find(to_lower_str(phrase)); - if (pos != std::string::npos && (match_pos == std::string::npos || pos < match_pos)) { - match_pos = pos; - detected_phrase = phrase; - } - } - for (size_t i = 0; i < stop_regexes.size(); ++i) { - std::smatch match; - if (std::regex_search(buffer_text, match, stop_regexes[i])) { - size_t pos = match.position(0); - if (match_pos == std::string::npos || pos < match_pos) { - match_pos = pos; - detected_phrase = "REGEX:" + regex_patterns[i]; - } - } - } - - if (match_pos != std::string::npos) { - std::cout << "Debug: Stop phrase '" << detected_phrase << "' detected. Initiating ban logic." << std::endl; - - size_t split_index = 0; - bool found_split = false; - for (size_t i = 0; i < token_offsets.size(); ++i) { - if (token_offsets[i] + get_content_str(token_buffer[i]).length() > match_pos) { - split_index = i; - found_split = true; - break; - } - } - - if (found_split) { - std::string temp_prompt_suffix = ""; - std::deque good_tokens; - for (size_t i = 0; i < split_index; ++i) { - json& item = token_buffer[i]; - if (item.contains("__raw_token_id")) item.erase("__raw_token_id"); - temp_prompt_suffix += get_content_str(item); - good_tokens.push_back(item); - } - - json& guilty_item = token_buffer[split_index]; - int guilty_token_id = get_token_id(guilty_item); - if (guilty_token_id == -1) { - auto t = ctx_server.tokenize(get_content_str(guilty_item), false); - if (!t.empty()) guilty_token_id = t[0]; - } - - if (guilty_token_id != -1) { - if (ban_slot_index != (int)split_index) { - current_step_bans.clear(); - ban_slot_index = (int)split_index; - } - current_step_bans.insert(guilty_token_id); - std::cout << "Debug: Banning token ID " << guilty_token_id << " at slot " << split_index << ". Total bans: " << current_step_bans.size() << std::endl; - - request_cancel(current_task_id); - ctx_server.queue_results.remove_waiting_task_id(current_task_id); - - // Generate Fix Token - json fix_data = data; - fix_data["prompt"] = current_prompt_str + temp_prompt_suffix; - fix_data["n_predict"] = 1; - if (!fix_data.contains("logit_bias")) fix_data["logit_bias"] = json::array(); - - if (fix_data["logit_bias"].is_array()) { - for (int banned_id : current_step_bans) fix_data["logit_bias"].push_back(json::array({ banned_id, ban_bias })); - } else if (fix_data["logit_bias"].is_object()) { - for (int banned_id : current_step_bans) fix_data["logit_bias"][std::to_string(banned_id)] = ban_bias; - } - - std::cout << "Debug: Fix Data Logit Bias: " << fix_data["logit_bias"].dump() << std::endl; - - int id_fix = ctx_server.queue_tasks.get_new_id(); - *active_task_id = id_fix; - ctx_server.queue_results.add_waiting_task_id(id_fix); - - auto fix_inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, fix_data["prompt"], true, true); - post_single_task(id_fix, fix_data, fix_inputs[0]); - - std::unordered_set fix_ids = { id_fix }; - server_task_result_ptr fix_result = nullptr; - while (!fix_result) { - fix_result = ctx_server.queue_results.recv_with_timeout(fix_ids, 1); - if (!fix_result && !sink.is_writable()) { - request_cancel(id_fix); - ctx_server.queue_results.remove_waiting_task_id(id_fix); - return false; - } - } - ctx_server.queue_results.remove_waiting_task_id(id_fix); - - if (fix_result->is_error()) { - std::cout << "Debug: Fix task failed with error." << std::endl; - send_sse(sink, fix_result->to_json()); - return false; - } - - json fix_token_json; - if (oaicompat) { - auto parts = format_partial_response_oaicompat(*fix_result, completion_id); - if (!parts.empty()) fix_token_json = parts[0]; - } else { - fix_token_json = fix_result->data; - } - if (fix_result->data.contains("token")) fix_token_json["__raw_token_id"] = fix_result->data["token"]; - - // Resume Generation - json resume_data = data; - bool stop_after_fix = false; - if (original_n_predict > 0) { - int pending = good_tokens.size() + 1; - if (total_tokens_streamed + pending >= original_n_predict) stop_after_fix = true; - else resume_data["n_predict"] = original_n_predict - (total_tokens_streamed + pending); - } - - if (stop_after_fix) { - token_buffer = good_tokens; - token_buffer.push_back(fix_token_json); - while (!token_buffer.empty()) { - json& item = token_buffer.front(); - if (item.contains("__raw_token_id")) item.erase("__raw_token_id"); - if (!send_sse(sink, item)) { - request_cancel(*active_task_id); - ctx_server.queue_results.remove_waiting_task_id(*active_task_id); - return false; - } - total_tokens_streamed++; - token_buffer.pop_front(); - } - successful_completion = true; - goto cleanup; - } - - resume_data["prompt"] = current_prompt_str + temp_prompt_suffix + get_content_str(fix_token_json); - current_task_id = ctx_server.queue_tasks.get_new_id(); - *active_task_id = current_task_id; - ctx_server.queue_results.add_waiting_task_id(current_task_id); - - auto resume_inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, resume_data["prompt"], true, true); - post_single_task(current_task_id, resume_data, resume_inputs[0]); - - token_buffer = good_tokens; - token_buffer.push_back(fix_token_json); - } - } - } - - // Flush Logic - bool should_flush_all = result->is_stop() || result->is_error(); - if (token_buffer.size() >= BUFFER_SIZE || should_flush_all) { - while (!token_buffer.empty()) { - if (!should_flush_all && token_buffer.size() < BUFFER_SIZE) break; - - json& item_to_send = token_buffer.front(); - if (item_to_send.contains("__raw_token_id")) item_to_send.erase("__raw_token_id"); - - current_prompt_str += get_content_str(item_to_send); - - if (ban_slot_index != -1) { - if (0 == ban_slot_index) { - current_step_bans.clear(); - ban_slot_index = -1; - } else { - ban_slot_index--; - } - } - - if (!send_sse(sink, item_to_send)) { - request_cancel(current_task_id); - ctx_server.queue_results.remove_waiting_task_id(current_task_id); - return false; - } - - total_tokens_streamed++; - token_buffer.pop_front(); - - if (original_n_predict > 0 && total_tokens_streamed >= original_n_predict) { - request_cancel(current_task_id); - ctx_server.queue_results.remove_waiting_task_id(current_task_id); - successful_completion = true; - goto cleanup; - } - } - } - - if (result->is_error()) { - ctx_server.queue_results.remove_waiting_task_id(current_task_id); - return false; - } - if (result->is_stop()) { - successful_completion = true; - break; - } - } - - cleanup: - bool ok = true; - if (successful_completion && oaicompat != OAICOMPAT_TYPE_ANTHROPIC && oaicompat != OAICOMPAT_TYPE_NONE) { - static const std::string done_message = "data: [DONE]\n\n"; - LOG_VERBOSE("data stream", { {"to_send", done_message} }); - if (!sink.write(done_message.c_str(), done_message.size())) ok = false; - } - sink.done(); - ctx_server.queue_results.remove_waiting_task_id(*active_task_id); - return ok; - - } catch (const std::exception& e) { - std::cerr << "Exception in streaming handler: " << e.what() << std::endl; - send_sse(sink, json{ {"error", {{"message", e.what()}, {"type", "server_error"}, {"code", 500}}} }); - sink.done(); - if (active_task_id) { - request_cancel(*active_task_id); - ctx_server.queue_results.remove_waiting_task_id(*active_task_id); - } - return false; - } catch (...) { - std::cerr << "Unknown exception in streaming handler" << std::endl; - send_sse(sink, json{ {"error", {{"message", "Unknown error"}, {"type", "server_error"}, {"code", 500}}} }); - sink.done(); - if (active_task_id) { - request_cancel(*active_task_id); - ctx_server.queue_results.remove_waiting_task_id(*active_task_id); - } - return false; - } - }; - - auto on_complete = [active_task_id, &ctx_server, request_cancel](bool) { - int id_to_cancel = *active_task_id; - request_cancel(id_to_cancel); - ctx_server.queue_results.remove_waiting_task_id(id_to_cancel); - }; - - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); -} - -// --- END OF HELPER FUNCTIONS --- - - - - - int main(int argc, char ** argv) { #if SERVER_VERBOSE != 1 log_disable(); @@ -1554,10 +1038,9 @@ int main(int argc, char ** argv) { }; - // handle completion-like requests (completion, chat, infill) // we can optionally provide a custom format for partial results and final results -const auto handle_completions_impl = [&ctx_server, ¶ms]( + const auto handle_completions_impl = [&ctx_server, ¶ms]( server_task_type type, json& data, const std::vector& files, @@ -1566,91 +1049,62 @@ const auto handle_completions_impl = [&ctx_server, ¶ms]( oaicompat_type oaicompat) -> void { GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); - // 1. Common Validation - std::string invalid_re = validate_banned_regex(data); - if (!invalid_re.empty()) { - res_err(res, format_error_response("Invalid regex: " + invalid_re, ERROR_TYPE_INVALID_REQUEST)); - return; - } - const auto completion_id = gen_chatcmplid(); - std::string model_name = get_model_name(ctx_server.params_base.model); - - auto list_has_content = [&](const std::string& key) { - if (data.contains(key) && data[key].is_array()) { - for (const auto& item : data[key]) { - if (item.is_string() && !item.get().empty()) { - return true; - } - } - } - return false; - }; - - bool has_banned_content = list_has_content("banned_strings") || - list_has_content("banned_regex") || - list_has_content("banned_regex_case_insensitive"); + // need to store the reader as a pointer, so that it won't be destroyed when the handle returns + // use shared_ptr as it's shared between the chunked_content_provider() and on_complete() + const auto rd = std::make_shared(ctx_server); - // 2. Common Prompt Processing - std::vector inputs; try { + std::vector tasks; + const auto& prompt = data.at("prompt"); + + // process prompt + std::vector inputs; + if (oaicompat && ctx_server.mctx != nullptr) { + // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below. inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); } else { + // Everything else, including multimodal completions. inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true); } + tasks.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + server_task task = server_task(type); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + + task.tokens = std::move(inputs[i]); + task.data = data; + //task.params = server_task::params_from_json_cmpl( + // ctx_server.ctx, + // ctx_server.params, + // data); + task.id_slot = json_value(data, "id_slot", -1); + + // OAI-compat + task.params.oaicompat = oaicompat; + task.params.oaicompat_cmpl_id = completion_id; + task.params.oaicompat_model = get_model_name(ctx_server.params_base.model); + tasks.push_back(std::move(task)); + } + + rd->post_tasks(std::move(tasks)); } catch (const std::exception& e) { res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); return; } - - // 3. Branching Logic - if (has_banned_content) { - // PATH B: Complex Logic (Delegated to external function) - handle_completions_banned_impl(ctx_server, type, data, std::move(inputs[0]), is_connection_closed, res, oaicompat, completion_id); - return; - } - - // PATH A: Standard Logic (Inline) - auto create_task = [type, oaicompat, completion_id, model_name](int id, int index, server_tokens tokens, json& task_data) { - server_task task(type); - task.id = id; - task.index = index; - task.tokens = std::move(tokens); - task.data = task_data; - task.id_slot = json_value(task_data, "id_slot", -1); - task.params.oaicompat = oaicompat; - task.params.oaicompat_cmpl_id = completion_id; - task.params.oaicompat_model = model_name; - return task; - }; - - auto send_sse = [oaicompat](httplib::DataSink& sink, const json& payload) -> bool { - if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) { - return server_sent_anthropic_event(sink, payload); - } else { - return server_sent_event(sink, payload); - } - }; - bool stream = json_value(data, "stream", false); - const auto rd = std::make_shared(ctx_server); - std::vector tasks; - tasks.reserve(inputs.size()); - - for (size_t i = 0; i < inputs.size(); i++) { - tasks.push_back(create_task(ctx_server.queue_tasks.get_new_id(), i, std::move(inputs[i]), data)); - } - rd->post_tasks(std::move(tasks)); - if (!stream) { + // non-stream, wait for the results auto all_results = rd->wait_for_all(is_connection_closed); if (all_results.is_terminated) { - llama_decode_stop(); - return; + llama_decode_stop(); // send a signal to stop decode process + return; // connection is closed } else if (all_results.error) { res_err(res, all_results.error->to_json()); @@ -1660,90 +1114,100 @@ const auto handle_completions_impl = [&ctx_server, ¶ms]( json arr = json::array(); for (auto& res : all_results.results) { GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - if (oaicompat) { - arr.push_back(format_final_response_oaicompat(data, res->data, completion_id, false)); - } else { - arr.push_back(res->to_json()); - } + arr.push_back(res->to_json()); } + // if single request, return single object instead of array res_ok(res, arr.size() == 1 ? arr[0] : arr); } } else { + // in streaming mode, the first error must be treated as non-stream response + // this is to match the OAI API behavior + // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 server_task_result_ptr first_result = rd->next(is_connection_closed); if (first_result == nullptr) { - llama_decode_stop(); - return; + llama_decode_stop(); // send a signal to stop decode process + return; // connection is closed } else if (first_result->is_error()) { res_err(res, first_result->to_json()); return; } - - std::vector first_result_parts; - if (oaicompat) { - first_result_parts = format_partial_response_oaicompat(*first_result, completion_id); - } else { - first_result_parts.push_back(first_result->to_json()); + else { + GGML_ASSERT( + dynamic_cast(first_result.get()) != nullptr + || dynamic_cast(first_result.get()) != nullptr + ); } - - const auto chunked_content_provider = [first_result_parts, rd, oaicompat, send_sse, completion_id](size_t, httplib::DataSink& sink) mutable -> bool { - for (auto& part : first_result_parts) { - if (!part.empty()) { - if (!send_sse(sink, part)) { - sink.done(); - return false; - } - part.clear(); + // next responses are streamed + json first_result_json = first_result->to_json(); + const auto chunked_content_provider = [first_result_json, rd, oaicompat](size_t, httplib::DataSink& sink) mutable -> bool { + const auto sse = [oaicompat, &sink](const json& res) { + if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) { + return server_sent_anthropic_event(sink, res); + } + else { + return server_sent_event(sink, res); + } + }; + // flush the first result as it's not an error + if (!first_result_json.empty()) { + if (!sse(first_result_json)) { + sink.done(); + return false; // sending failed, go to on_complete() } + first_result_json.clear(); // mark as sent } + // receive subsequent results auto result = rd->next([&sink] { return !sink.is_writable(); }); if (result == nullptr) { sink.done(); - return false; + return false; // connection is closed, go to on_complete() } + // send the results + json res_json = result->to_json(); + bool ok = false; if (result->is_error()) { - send_sse(sink, json{ { "error", result->to_json() } }); + ok = sse(json{ { "error", result->to_json() } }); sink.done(); - return false; + return false; // go to on_complete() } - - if (oaicompat) { - std::vector parts = format_partial_response_oaicompat(*result, completion_id); - for (const auto& part : parts) { - if (!send_sse(sink, part)) { - sink.done(); - return false; - } - } - } else { - if (!send_sse(sink, result->to_json())) { - sink.done(); - return false; - } + else { + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + ok = sse(res_json); + } + + if (!ok) { + sink.done(); + return false; // sending failed, go to on_complete() } + // check if there is more data if (!rd->has_next()) { if (oaicompat != OAICOMPAT_TYPE_ANTHROPIC && oaicompat != OAICOMPAT_TYPE_NONE) { static const std::string ev_done = "data: [DONE]\n\n"; sink.write(ev_done.data(), ev_done.size()); } sink.done(); - return false; + return false; // no more data, go to on_complete() } + + // has next data, continue return true; }; - auto on_complete = [rd](bool) { rd->stop(); }; + auto on_complete = [rd](bool) { + rd->stop(); + }; res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); } }; - - - const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { auto data = json::parse(req.body); std::vector files; // dummy From e9cd2549c35f63aa0b27a06eb6c4787f98aba2f4 Mon Sep 17 00:00:00 2001 From: SneedwareInc <254158255+SneedwareInc@users.noreply.github.com> Date: Fri, 6 Feb 2026 05:35:57 +0100 Subject: [PATCH 7/8] Update2 --- common/common.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index d60aef737..9002a8f9e 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #if defined(__APPLE__) && defined(__MACH__) #include @@ -2651,14 +2652,10 @@ std::string string_get_sortable_timestamp() { return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns); } -// could be improved to support more languages std::string string_lower(const std::string& str) { std::string result = str; - for (char& c : result) { - if (c >= 'A' && c <= 'Z') { - c = static_cast(c + ('a' - 'A')); - } - } + std::transform(result.begin(), result.end(), result.begin(), + [](unsigned char c) { return std::tolower(c); }); return result; } From ccf2cef0b7604d055c8fac7d9944ee714ef9894e Mon Sep 17 00:00:00 2001 From: SneedwareInc <254158255+SneedwareInc@users.noreply.github.com> Date: Fri, 6 Feb 2026 08:47:22 +0100 Subject: [PATCH 8/8] Fix attempt #1 Needs testing --- examples/server/server-context.cpp | 191 +++++++++++++++++++---------- examples/server/server-context.h | 1 + 2 files changed, 128 insertions(+), 64 deletions(-) diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 5b97127ad..0f2b677e1 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -335,6 +335,9 @@ void server_slot::reset() { generated_token_probs.clear(); + // --- FIX: Clear positional bans on reset --- + positional_bans.clear(); + // ------------------------------------------ // Reset speculative decoding stats n_draft_total = 0; @@ -3011,15 +3014,21 @@ bool server_context::accept_special_token(const server_slot& slot, const llama_ return params_base.special || slot.sparams.preserved_tokens.find(token) != slot.sparams.preserved_tokens.end(); }; - void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) { int count = 0; bool released = false; + + // FIX: Restored the +1 to match check_ban_phrase + int32_t start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1; + for (auto& it : results) { bool has_next = process_token(it, slot); + + // Clean up positional bans for the token we just confirmed/sent + slot.positional_bans.erase(start_pos + count); + count++; if (!has_next) { - // If stopped by limit, continue processing the buffer to ensure all generated tokens are sent if (slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) { continue; } @@ -3035,7 +3044,6 @@ void server_context::send_token_results(completion_token_outputs& results, serve } } - // If we finished the loop due to limit (and didn't release yet), release now if (!released && slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) { slot.release(); slot.print_timings(); @@ -3047,25 +3055,26 @@ void server_context::send_token_results(completion_token_outputs& results, serve slot.sampled = results[results.size()-1].tok; results.erase(results.begin(), results.begin() + count); } - } -inline int32_t check_ban_phrase(const server_slot& slot) { - bool found = false; - size_t n = slot.token_buffer.size(); - size_t best_start = std::string::npos; - int32_t n_rewind = 0; +inline int32_t check_ban_phrase(server_slot& slot) { + if (slot.token_buffer.empty()) return 0; + std::string string_buffer; - llama_tokens tokens; - for (auto& it : slot.token_buffer) { - string_buffer = string_buffer + it.text_to_send; - tokens.push_back(it.tok); + std::vector token_offsets; + + for (const auto& it : slot.token_buffer) { + token_offsets.push_back(string_buffer.size()); + string_buffer += it.text_to_send; } + + size_t best_start = std::string::npos; + bool found = false; std::string string_buffer_lower = string_lower(string_buffer); - - // Check strings - for (auto it : slot.ban_phrases) { - size_t start = string_buffer_lower.find(it); + + // 1. Check strings + for (const auto& phrase : slot.ban_phrases) { + size_t start = string_buffer_lower.find(phrase); if (start != std::string::npos) { if (start < best_start) { best_start = start; @@ -3074,7 +3083,7 @@ inline int32_t check_ban_phrase(const server_slot& slot) { } } - // Check regex + // 2. Check regex for (const auto& pattern : slot.ban_regex) { try { std::regex re(pattern); @@ -3085,13 +3094,10 @@ inline int32_t check_ban_phrase(const server_slot& slot) { found = true; } } - } catch (const std::regex_error& e) { - // Should be caught during validation, but safe fallback - continue; - } + } catch (...) { continue; } } - // Check regex case insensitive + // 3. Check regex case insensitive for (const auto& pattern : slot.ban_regex_ci) { try { std::regex re(pattern, std::regex_constants::icase); @@ -3102,48 +3108,74 @@ inline int32_t check_ban_phrase(const server_slot& slot) { found = true; } } - } catch (const std::regex_error& e) { - continue; - } + } catch (...) { continue; } } if (found) { - std::vector unused; - LLAMA_LOG_DEBUG("Banned string/regex dectected: %s\n ", string_buffer.substr(best_start).c_str()); - n = find_n_tokens_from_string(slot.ctx, tokens, best_start, 0, unused); - n_rewind = (int32_t) slot.token_buffer.size() - (int32_t) n; + int32_t token_idx = -1; + for (size_t i = 0; i < token_offsets.size(); ++i) { + size_t len = (i == token_offsets.size() - 1) + ? string_buffer.size() - token_offsets[i] + : token_offsets[i+1] - token_offsets[i]; + + if (best_start >= token_offsets[i] && best_start < token_offsets[i] + len) { + token_idx = (int32_t)i; + break; + } + } + + if (token_idx != -1) { + // FIX: Restored the +1 as requested. + // n_past is the index of the next token. + // If buffer has 1 token, it is at n_past - 1. + // With +1: n_past - 1 + 1 = n_past. + // This bans the token at the current n_past index (the one we just generated). + int32_t abs_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1 + token_idx; + llama_token banned_tok = slot.token_buffer[token_idx].tok; + + LLAMA_LOG_INFO("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n", + abs_pos, banned_tok, slot.token_buffer[token_idx].text_to_send.c_str()); + + slot.positional_bans[abs_pos].insert(banned_tok); + + return (int32_t)slot.token_buffer.size() - token_idx; + } } - return n_rewind; + + return 0; } inline void rewind_context(server_slot& slot, int32_t n_rewind) { slot.rewind_count++; - int32_t n_keep_rewind = (int32_t)slot.token_buffer.size() - n_rewind; - std::set tokens; - // ban all tokens for better coherence - if (slot.banned_n != 0) { - int32_t n = 0; - for (auto result = slot.token_buffer.begin() + n_keep_rewind; result != slot.token_buffer.end(); result++) - { - if (!tokens.contains(result->tok)) { - slot.ctx_sampling->params.logit_bias[result->tok] += slot.ban_phrases_bias; - } - else { - tokens.insert(result->tok); - } - n++; - if (slot.banned_n > 0 && n == slot.banned_n) { - break; - } - } + + size_t n_remove = n_rewind; + if (n_remove > slot.cache_tokens.size()) { + n_remove = slot.cache_tokens.size(); + } + + size_t n_keep = slot.cache_tokens.size() - n_remove; + + // Set sampled to the token we are "keeping" at the end, so it gets re-added + // to the batch in the next update_slots cycle. + if (n_keep < slot.cache_tokens.size()) { + slot.sampled = slot.cache_tokens[n_keep]; + } else { + slot.sampled = 0; } - slot.token_buffer.resize(n_keep_rewind); - size_t n_keep = slot.cache_tokens.size() - n_rewind; - slot.sampled = slot.cache_tokens[n_keep]; + // Truncate cache slot.cache_tokens.keep_first(n_keep); + slot.n_past = slot.cache_tokens.n_tokens(); - // Adjust decoded count so user doesn't lose budget on banned tokens + // Remove from KV cache + llama_kv_cache_seq_rm(slot.ctx, slot.id, slot.n_past, -1); + + // Truncate buffer + int32_t n_keep_buffer = (int32_t)slot.token_buffer.size() - n_rewind; + if (n_keep_buffer < 0) n_keep_buffer = 0; + slot.token_buffer.resize(n_keep_buffer); + + // Adjust decoded count slot.n_decoded -= n_rewind; if (slot.n_decoded < 0) slot.n_decoded = 0; } @@ -3152,38 +3184,39 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_ slot.token_buffer.push_back(result); bool next_token = has_next_token(result, slot); - bool send_result = slot.token_buffer.size() >= slot.n_buffer || !next_token; + // If buffer full or generation stopped, we might send tokens + bool buffer_full = slot.token_buffer.size() >= slot.n_buffer; + int32_t n_rewind = 0; - // don't restore if last time was also rewind - if (!slot.rewind_status) { - slot.ctx_sampling->params.logit_bias = slot.logit_bias; // restore logit bias - } + if (slot.ban_phrases.size() > 0 || slot.ban_regex.size() > 0 || slot.ban_regex_ci.size() > 0) { n_rewind = check_ban_phrase(slot); } - // if found string in the ban + if (n_rewind > 0) { rewind_context(slot, n_rewind); slot.rewind_status = true; } - else if (send_result) { + else if (buffer_full || !next_token) { slot.rewind_status = false; slot.rewind_count = 0; + if (!next_token) { - // send all remaining tokens in the buffer + // send all remaining tokens send_token_results(slot.token_buffer, slot); } else { - // send 1 token + // send 1 token from the front (FIFO) send_token_results(slot.token_buffer, slot, 1); } } else { - // buffer the result - slot.sampled = result.tok; // for common batch add + // buffer the result, wait for more tokens to validate string + slot.sampled = result.tok; } } + void server_context::process_batch_tokens(int32_t & n_batch) { for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); @@ -3259,8 +3292,38 @@ void server_context::process_batch_tokens(int32_t & n_batch) { completion_token_output result; const int tok_idx = slot.i_batch - i; + + // --- START POSITIONAL BAN LOGIC --- + // Check if we have specific bans for this exact position (slot.n_past) + // Note: slot.n_past is the index of the token we are about to generate. + auto pos_ban_it = slot.positional_bans.find(slot.n_past); + std::vector temp_banned; + + if (pos_ban_it != slot.positional_bans.end()) { + for (llama_token banned_tok : pos_ban_it->second) { + // Only ban if not already banned by user to avoid overwriting -INF + if (slot.ctx_sampling->params.logit_bias.find(banned_tok) == slot.ctx_sampling->params.logit_bias.end() || + slot.ctx_sampling->params.logit_bias[banned_tok] > -1000.0f) { + + slot.ctx_sampling->params.logit_bias[banned_tok] = -INFINITY; + temp_banned.push_back(banned_tok); + } + } + } + // --- END POSITIONAL BAN LOGIC --- + const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, NULL, tok_idx); + // --- RESTORE LOGIT BIAS --- + for (llama_token banned_tok : temp_banned) { + if (slot.logit_bias.count(banned_tok)) { + slot.ctx_sampling->params.logit_bias[banned_tok] = slot.logit_bias[banned_tok]; + } else { + slot.ctx_sampling->params.logit_bias.erase(banned_tok); + } + } + // --- END RESTORE --- + common_sampler_accept(slot.ctx_sampling, ctx, id, true); slot.n_decoded += 1; diff --git a/examples/server/server-context.h b/examples/server/server-context.h index 088e3464b..ea28347bb 100644 --- a/examples/server/server-context.h +++ b/examples/server/server-context.h @@ -94,6 +94,7 @@ struct server_slot { completion_token_outputs token_buffer; float ban_phrases_bias = 0; int32_t banned_n = 1; + std::map> positional_bans; server_prompt server_cached_prompt;