From df600c1ec0d8e396f40f1cff654c761ddc46dbbd Mon Sep 17 00:00:00 2001 From: Saood Karim Date: Fri, 6 Jun 2025 12:23:07 -0500 Subject: [PATCH] Add list_saved_prompts function to server --- examples/server/server.cpp | 45 +++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 466bb339c..8724e8d8b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3390,6 +3390,48 @@ int main(int argc, char ** argv) { res.status = 200; // HTTP OK }; + const auto list_saved_prompts = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + json response = json::array(); + namespace fs = std::filesystem; + + try { + for (const auto& entry : fs::directory_iterator(params.slot_save_path)) { + if (!entry.is_regular_file() || entry.file_size() < 12) { + continue; + } + + std::ifstream file(entry.path(), std::ios::binary); + if (!file) continue; + + uint32_t magic, version, n_token_count; + file.read(reinterpret_cast(&magic), sizeof(magic)); + file.read(reinterpret_cast(&version), sizeof(version)); + file.read(reinterpret_cast(&n_token_count), sizeof(n_token_count)); + + if (magic != LLAMA_STATE_SEQ_MAGIC || + version != LLAMA_STATE_SEQ_VERSION || + entry.file_size() < (12 + (n_token_count * sizeof(llama_token)))) { + continue; + } + + std::vector tokens(n_token_count); + file.read(reinterpret_cast(tokens.data()), tokens.size() * sizeof(llama_token)); + + response.push_back({ + {"filename", entry.path().filename().string()}, + {"filesize", entry.file_size()}, + {"token_count", n_token_count}, + {"prompt", tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend())} + }); + } + } catch (const std::exception& e) { + res.status = 500; + response = {{"error", e.what()}}; + } + res.set_content(response.dump(), "application/json; charset=utf-8"); + }; + auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) { return [content, len, mime_type](const httplib::Request &, httplib::Response & res) { res.set_content(reinterpret_cast(content), len, mime_type); @@ -3448,8 +3490,9 @@ int main(int argc, char ** argv) { // Save & load slots svr->Get ("/slots", handle_slots); if (!params.slot_save_path.empty()) { - // only enable slot endpoints if slot_save_path is set + // these endpoints rely on slot_save_path existing svr->Post("/slots/:id_slot", handle_slots_action); + svr->Get ("/list", list_saved_prompts); } //