diff --git a/common/arg.cpp b/common/arg.cpp index 6751a55ab0c..710955a86fb 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3072,6 +3072,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.models_max = value; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_MAX")); + add_opt(common_arg( + {"--models-memory-margin"}, "N", + string_format("for router server, MiB of memory to leave free, per device (default: %d, 0 = unlimited)", params.models_memory_margin), + [](common_params & params, int value) { + params.models_memory_margin = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_MEMORY_MARGIN")); add_opt(common_arg( {"--models-autoload"}, {"--no-models-autoload"}, @@ -3301,6 +3308,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.offline = true; } ).set_env("LLAMA_OFFLINE")); + add_opt(common_arg( + {"--download-only"}, + "Download the model file(s) and exit", + [](common_params & params) { + params.download_only = true; + } + )); add_opt(common_arg( {"-lv", "--verbosity", "--log-verbosity"}, "N", string_format("Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:\n" diff --git a/common/common.h b/common/common.h index 4137a87f1d2..066e5766502 100644 --- a/common/common.h +++ b/common/common.h @@ -482,6 +482,7 @@ struct common_params { int32_t control_vector_layer_start = -1; // layer range for control vector int32_t control_vector_layer_end = -1; // layer range for control vector bool offline = false; + bool download_only = false; // only download the model if required, don't start the server int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line @@ -607,10 +608,11 @@ struct common_params { std::vector server_tools; // router server configs - std::string models_dir = ""; // directory containing models for the router server - std::string models_preset = ""; // directory containing model presets for the router server - int models_max = 4; // maximum number of models to load simultaneously - bool models_autoload = true; // automatically load models when requested via the router server + std::string models_dir = ""; // directory containing models for the router server + std::string models_preset = ""; // directory containing model presets for the router server + int models_max = 4; // maximum number of models to load simultaneously + int models_memory_margin = 1024; // MiB of free memory to preserve per device (0 = disabled) + bool models_autoload = true; // automatically load models when requested via the router server bool log_json = false; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 8126249e143..79437bbd177 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -3493,6 +3493,19 @@ void llama_perf_context_reset(llama_context * ctx) { ctx->perf_reset(); } +uint64_t llama_context_device_memory(const llama_context * ctx, ggml_backend_dev_t device) { + const bool is_host = ggml_backend_dev_type(device) == GGML_BACKEND_DEVICE_TYPE_CPU; + uint64_t total = 0; + for (const auto & [buft, mb] : ctx->memory_breakdown()) { + const bool matches = is_host ? ggml_backend_buft_is_host(buft) : + ggml_backend_buft_get_device(buft) == device; + if (matches) { + total += mb.total(); + } + } + return total; +} + // // training // diff --git a/src/llama-ext.h b/src/llama-ext.h index 8ce29d217cb..ce87fa32a4a 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -88,3 +88,9 @@ LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model); LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i); LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx); + +// Returns the projected memory use (model + context + compute) in bytes +// for the given device within this context. Returns 0 if the device is not used. +LLAMA_API uint64_t llama_context_device_memory( + const struct llama_context * ctx, + ggml_backend_dev_t device); diff --git a/tools/server/server-http.h b/tools/server/server-http.h index 68ae2170cf6..42ea8a8e992 100644 --- a/tools/server/server-http.h +++ b/tools/server/server-http.h @@ -28,7 +28,13 @@ struct server_http_res { return next != nullptr; } - virtual ~server_http_res() = default; + std::function on_destroy = nullptr; + + virtual ~server_http_res() { + if (on_destroy) { + on_destroy(); + } + } }; // unique pointer, used by set_chunked_content_provider diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 6066611f51c..379b01a4f03 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -8,6 +8,8 @@ #include // TODO: remove this once we use HTTP client from download.h #include +#include "../../src/llama-ext.h" + #include #include #include @@ -94,6 +96,7 @@ static void unset_reserved_args(common_preset & preset, bool unset_model_args) { preset.unset_option("LLAMA_API_KEY"); preset.unset_option("LLAMA_ARG_MODELS_DIR"); preset.unset_option("LLAMA_ARG_MODELS_MAX"); + preset.unset_option("LLAMA_ARG_MODELS_MEMORY_MARGIN"); preset.unset_option("LLAMA_ARG_MODELS_PRESET"); preset.unset_option("LLAMA_ARG_MODELS_AUTOLOAD"); if (unset_model_args) { @@ -177,9 +180,27 @@ server_models::server_models( bin_path = get_server_exec_path().string(); } catch (const std::exception & e) { bin_path = argv[0]; - LOG_WRN("failed to get server executable path: %s\n", e.what()); - LOG_WRN("using original argv[0] as fallback: %s\n", argv[0]); + SRV_WRN("failed to get server executable path: %s\n", e.what()); + SRV_WRN("using original argv[0] as fallback: %s\n", argv[0]); + } + + const size_t memory_margin = (size_t) base_params.models_memory_margin * 1024 * 1024; + + if (memory_margin > 0) { + const size_t n_devs = ggml_backend_dev_count(); + for (size_t i = 0; i < n_devs; i++) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + size_t free, total; + ggml_backend_dev_memory(dev, &free, &total); + if (total > 0) { + const size_t available = (free > memory_margin) ? free - memory_margin : 0; + dmm_available[dev] = available; + SRV_DBG("device %s: available memory after margin=%zu MiB\n", + ggml_backend_dev_name(dev), available / (1024 * 1024)); + } + } } + load_models(); } @@ -295,16 +316,17 @@ void server_models::load_models() { // convert presets to server_model_meta and add to mapping for (const auto & preset : final_presets) { server_model_meta meta{ - /* preset */ preset.second, - /* name */ preset.first, - /* aliases */ {}, - /* tags */ {}, - /* port */ 0, - /* status */ SERVER_MODEL_STATUS_UNLOADED, - /* last_used */ 0, - /* args */ std::vector(), - /* exit_code */ 0, - /* stop_timeout */ DEFAULT_STOP_TIMEOUT, + /* preset */ preset.second, + /* name */ preset.first, + /* aliases */ {}, + /* tags */ {}, + /* port */ 0, + /* status */ SERVER_MODEL_STATUS_UNLOADED, + /* last_used */ 0, + /* memory_per_device */ {}, + /* args */ std::vector(), + /* exit_code */ 0, + /* stop_timeout */ DEFAULT_STOP_TIMEOUT, }; add_model(std::move(meta)); } @@ -495,49 +517,316 @@ std::vector server_models::get_all_meta() { return result; } -void server_models::unload_lru() { - if (base_params.models_max <= 0) { - return; // no limit - } - // remove one of the servers if we passed the models_max (least recently used - LRU) - std::string lru_model_name = ""; - int64_t lru_last_used = ggml_time_ms(); - size_t count_active = 0; +void server_models::inc_refs(const std::string & name) { + std::lock_guard lk(mutex); + mapping[name].active_refs++; +} + +void server_models::dec_refs(const std::string & name) { { - std::unique_lock lk(mutex); - for (const auto & m : mapping) { - if (m.second.meta.is_running()) { - count_active++; - if (m.second.meta.last_used < lru_last_used) { - lru_model_name = m.first; - lru_last_used = m.second.meta.last_used; - } + std::lock_guard lk(mutex); + mapping[name].active_refs--; + } + cv.notify_all(); +} + +int server_models::can_fit(const device_memory_map & dmm_req) const { + device_memory_map dmm_total; + for (const auto & m : mapping) { + if (m.second.meta.is_running()) { + for (const auto & [dev, mem] : m.second.meta.dmm_req) { + dmm_total[dev] += mem; } } } - if (!lru_model_name.empty() && count_active >= (size_t)base_params.models_max) { - SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str()); - unload(lru_model_name); - // wait for unload to complete + + auto get = [](const device_memory_map & dmm, ggml_backend_dev_t dev) { + auto it = dmm.find(dev); + return it != dmm.end() ? it->second : 0; + }; + + int res = 0; + + for (const auto & [dev, limit] : dmm_available) { + const size_t mem_total = get(dmm_total, dev); + const size_t mem_new = get(dmm_req, dev); + + SRV_DBG("device %s: total=%zu MiB, new=%zu MiB, limit=%zu MiB\n", + ggml_backend_dev_name(dev), + mem_total / (1024 * 1024), mem_new / (1024 * 1024), limit / (1024 * 1024)); + + if (mem_total + mem_new > limit) { + res++; + } + } + + return res; +} + +void server_models::unload_lru(const device_memory_map & dmm_req) { + const bool check_active = base_params.models_max > 0; + const bool check_memory = base_params.models_memory_margin > 0; + + if (!check_active && !check_memory) { + return; // no limit + } + + if (check_memory) { + GGML_ASSERT(!dmm_available.empty()); + } + + while (true) { + std::string lru_model_name; + int64_t lru_last_used = ggml_time_ms(); + + int count_active = 0; + int count_exceed = 0; { std::unique_lock lk(mutex); - cv.wait(lk, [this, &lru_model_name]() { - return mapping[lru_model_name].meta.status == SERVER_MODEL_STATUS_UNLOADED; + for (const auto & m : mapping) { + if (m.second.meta.is_running()) { + count_active++; + // Only consider idle models + if (m.second.active_refs == 0 && m.second.meta.last_used < lru_last_used) { + lru_model_name = m.first; + lru_last_used = m.second.meta.last_used; + } + } + } + if (check_memory) { + count_exceed = can_fit(dmm_req); + } + } + + const bool active_exceeded = check_active && count_active >= base_params.models_max; + const bool memory_exceeded = check_memory && count_exceed > 0; + + if (!lru_model_name.empty() && (active_exceeded || memory_exceeded)) { + SRV_INF("limits reached (count=%d, memory margin exceeded on %d device(s)), removing LRU name=%s\n", + count_active, count_exceed, lru_model_name.c_str()); + unload(lru_model_name); + // wait for unload to complete + { + std::unique_lock lk(mutex); + cv.wait(lk, [this, &lru_model_name]() { + return mapping[lru_model_name].meta.status == SERVER_MODEL_STATUS_UNLOADED; + }); + } + } else if (count_active > 0 && (active_exceeded || memory_exceeded)) { + // No model idle, wait for drain + std::unique_lock lk(mutex); + bool drained = cv.wait_for(lk, std::chrono::seconds(DEFAULT_STOP_TIMEOUT), [this]() { + for (const auto & m : mapping) { + if (m.second.meta.is_running() && m.second.active_refs == 0) { + return true; + } + } + return false; }); + if (!drained) { + SRV_WRN("%s", "drain timeout, falling back to force eviction\n"); + break; + } + } else { + break; } } } +static std::string resolve_model_path(const common_preset & preset) { + common_params params; + preset.apply_to_params(params); + + if (!params.model.path.empty()) { + return params.model.path; + } + + if (!params.model.hf_repo.empty() || !params.model.url.empty()) { + common_download_opts opts; + opts.offline = true; + auto result = common_download_model(params.model, opts); + return result.model_path; + } + + return ""; +} + +static device_memory_map get_model_memory_per_device(const common_preset & preset) { + common_params params; + preset.apply_to_params(params); + + if(params.model.path.empty()) { + params.model.path = resolve_model_path(preset); + if(params.model.path.empty()) { + return {}; + } + } + + struct log_ud_t { + struct { + ggml_log_callback callback; + void * user_data; + } original; + ggml_log_level min_level; + } log_ud; + llama_log_get(&log_ud.original.callback, &log_ud.original.user_data); + log_ud.min_level = GGML_LOG_LEVEL_WARN; + + llama_log_set([](ggml_log_level level, const char * text, void * ud) { + log_ud_t * d = (log_ud_t *) ud; + const ggml_log_level eff = level >= d->min_level ? level : GGML_LOG_LEVEL_DEBUG; + d->original.callback(eff, text, d->original.user_data); + }, &log_ud); + + llama_model_params mparams = common_model_params_to_llama(params); + mparams.no_alloc = true; + mparams.use_mmap = false; + mparams.use_mlock = false; + + llama_model_ptr model{llama_model_load_from_file(params.model.path.c_str(), mparams)}; + + if (!model) { + llama_log_set(log_ud.original.callback, log_ud.original.user_data); + return {}; + } + + llama_context_params cparams = common_context_params_to_llama(params); + llama_context_ptr ctx{llama_init_from_model(model.get(), cparams)}; + llama_log_set(log_ud.original.callback, log_ud.original.user_data); + + if (!ctx) { + return {}; + } + + device_memory_map result; + const size_t n_devs = ggml_backend_dev_count(); + for (size_t i = 0; i < n_devs; i++) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + uint64_t bytes = llama_context_device_memory(ctx.get(), dev); + if (bytes > 0) { + result[dev] = bytes; + } + } + + return result; +} + +bool server_models::download_model(const std::string & name) { + std::vector child_args; + std::vector child_env; + { + std::lock_guard lk(mutex); + auto & meta = mapping[name].meta; + child_args = meta.preset.to_args(bin_path); + child_env = base_env; + } + child_args.push_back("--download-only"); + + SRV_INF("downloading model name=%s\n", name.c_str()); + + std::vector argv = to_char_ptr_array(child_args); + std::vector envp = to_char_ptr_array(child_env); + + subprocess_s proc; + int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr; + if (subprocess_create_ex(argv.data(), options, envp.data(), &proc) != 0) { + SRV_ERR("failed to spawn download process for model name=%s\n", name.c_str()); + return false; + } + + FILE * out = subprocess_stdout(&proc); + if (out) { + char buffer[4096]; + while (fgets(buffer, sizeof(buffer), out) != nullptr) { + LOG("[dl:%s] %s", name.c_str(), buffer); + } + } + + int exit_code = 0; + subprocess_join(&proc, &exit_code); + subprocess_destroy(&proc); + + if (exit_code != 0) { + SRV_ERR("download process for model name=%s exited with code %d\n", name.c_str(), exit_code); + return false; + } + + SRV_INF("download complete for model name=%s\n", name.c_str()); + return true; +} + void server_models::load(const std::string & name) { if (!has_model(name)) { throw std::runtime_error("model name=" + name + " is not found"); } - unload_lru(); + + { + common_preset preset_copy; + { + std::lock_guard lk(mutex); + preset_copy = mapping[name].meta.preset; + } + if (resolve_model_path(preset_copy).empty()) { + { + std::lock_guard lk(mutex); + auto & meta = mapping[name].meta; + if (meta.status != SERVER_MODEL_STATUS_UNLOADED) { + return; + } + meta.status = SERVER_MODEL_STATUS_DOWNLOADING; + cv.notify_all(); + } + std::thread([this, name]() { + if (!download_model(name)) { + update_status(name, SERVER_MODEL_STATUS_UNLOADED, 1); + return; + } + device_memory_map mem; + if (base_params.models_memory_margin > 0) { + std::lock_guard lk(mutex); + auto & meta = mapping[name].meta; + meta.dmm_req = get_model_memory_per_device(meta.preset); + if (meta.dmm_req.empty()) { + SRV_WRN("failed to estimate memory for model %s, memory limits will not apply\n", name.c_str()); + } + mem = meta.dmm_req; + } + try { + _load(name, mem); + } catch (const std::exception & e) { + SRV_ERR("failed to load model %s after download: %s\n", name.c_str(), e.what()); + update_status(name, SERVER_MODEL_STATUS_UNLOADED, 1); + } + }).detach(); + return; + } + } + + device_memory_map dmm_req; + if (base_params.models_memory_margin > 0) { + // determine the required memory by the model upon its first load + std::lock_guard lk(mutex); + auto & meta = mapping[name].meta; + if (meta.dmm_req.empty()) { + meta.dmm_req = get_model_memory_per_device(meta.preset); + if (meta.dmm_req.empty()) { + SRV_WRN("failed to estimate memory for model %s, memory limits will not apply\n", name.c_str()); + } + } + + dmm_req = meta.dmm_req; + } + + _load(name, dmm_req); +} + +void server_models::_load(const std::string & name, const device_memory_map & dmm_req) { + unload_lru(dmm_req); std::lock_guard lk(mutex); auto meta = mapping[name].meta; - if (meta.status != SERVER_MODEL_STATUS_UNLOADED) { + if (meta.status != SERVER_MODEL_STATUS_UNLOADED && meta.status != SERVER_MODEL_STATUS_DOWNLOADING) { SRV_INF("model %s is not ready\n", name.c_str()); return; } @@ -546,15 +835,24 @@ void server_models::load(const std::string & name) { // exceeding models_max. Without this, the window between unload_lru() // releasing its lock and this lock_guard acquiring allows multiple // threads to each observe capacity and all proceed to load. - if (base_params.models_max > 0) { - size_t count_active = 0; - for (const auto & m : mapping) { - if (m.second.meta.is_running()) { - count_active++; + { + const bool check_active = base_params.models_max > 0; + const bool check_memory = base_params.models_memory_margin > 0; + + if (check_active || check_memory) { + int count_active = 0; + for (const auto & m : mapping) { + if (m.second.meta.is_running()) { + count_active++; + } + } + + const bool active_exceeded = check_active && count_active >= base_params.models_max; + const bool memory_exceeded = check_memory && can_fit(dmm_req) > 0; + + if (active_exceeded || memory_exceeded) { + throw std::runtime_error("model limit reached, try again later"); } - } - if (count_active >= (size_t)base_params.models_max) { - throw std::runtime_error("model limit reached, try again later"); } } @@ -564,6 +862,7 @@ void server_models::load(const std::string & name) { inst.meta.port = get_free_port(); inst.meta.status = SERVER_MODEL_STATUS_LOADING; inst.meta.last_used = ggml_time_ms(); + inst.active_refs = mapping[name].active_refs; if (inst.meta.port <= 0) { throw std::runtime_error("failed to get a port number"); @@ -758,7 +1057,8 @@ void server_models::wait_until_loading_finished(const std::string & name) { cv.wait(lk, [this, &name]() { auto it = mapping.find(name); if (it != mapping.end()) { - return it->second.meta.status != SERVER_MODEL_STATUS_LOADING; + return it->second.meta.status != SERVER_MODEL_STATUS_LOADING && + it->second.meta.status != SERVER_MODEL_STATUS_DOWNLOADING; } return false; }); @@ -898,10 +1198,18 @@ static bool router_validate_model(std::string & name, server_models & models, bo } // resolve alias to canonical model name name = meta->name; + // To avoid unloading a model before it is loaded, protect with increased ref count before it starts loading + models.inc_refs(name); if (models_autoload) { - models.ensure_model_ready(name); + try { + models.ensure_model_ready(name); + } catch (...) { + models.dec_refs(name); + throw; + } } else { if (!meta->is_running()) { + models.dec_refs(name); res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); return false; } @@ -952,7 +1260,17 @@ void server_models_routes::init_routes() { if (!router_validate_model(name, models, autoload, error_res)) { return error_res; } - return models.proxy_request(req, method, name, false); + server_http_res_ptr proxy; + try { + proxy = models.proxy_request(req, method, name, false); + } catch(...) { + models.dec_refs(name); + throw; + } + proxy->on_destroy = [this, name]() { + this->models.dec_refs(name); + }; + return proxy; }; this->proxy_post = [this](const server_http_req & req) { @@ -964,7 +1282,17 @@ void server_models_routes::init_routes() { if (!router_validate_model(name, models, autoload, error_res)) { return error_res; } - return models.proxy_request(req, method, name, true); // update last usage for POST request only + server_http_res_ptr proxy; + try { + proxy = models.proxy_request(req, method, name, true); // update last usage for POST request only + } catch(...) { + models.dec_refs(name); + throw; + } + proxy->on_destroy = [this, name]() { + this->models.dec_refs(name); + }; + return proxy; }; this->post_router_models_load = [this](const server_http_req & req) { diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 1db34b6c4df..36cd0296f60 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -14,6 +14,9 @@ /** * state diagram: * + * + * ┌► DOWNLOADING ─┐ + * │ ▼ * UNLOADED ──► LOADING ──► LOADED ◄──── SLEEPING * ▲ │ │ ▲ * └───failed───┘ │ │ @@ -21,8 +24,8 @@ * └────────unloaded─────────┘ */ enum server_model_status { - // TODO: also add downloading state when the logic is added SERVER_MODEL_STATUS_UNLOADED, + SERVER_MODEL_STATUS_DOWNLOADING, SERVER_MODEL_STATUS_LOADING, SERVER_MODEL_STATUS_LOADED, SERVER_MODEL_STATUS_SLEEPING @@ -32,6 +35,9 @@ static server_model_status server_model_status_from_string(const std::string & s if (status_str == "unloaded") { return SERVER_MODEL_STATUS_UNLOADED; } + if (status_str == "downloading") { + return SERVER_MODEL_STATUS_DOWNLOADING; + } if (status_str == "loading") { return SERVER_MODEL_STATUS_LOADING; } @@ -46,14 +52,17 @@ static server_model_status server_model_status_from_string(const std::string & s static std::string server_model_status_to_string(server_model_status status) { switch (status) { - case SERVER_MODEL_STATUS_UNLOADED: return "unloaded"; - case SERVER_MODEL_STATUS_LOADING: return "loading"; - case SERVER_MODEL_STATUS_LOADED: return "loaded"; - case SERVER_MODEL_STATUS_SLEEPING: return "sleeping"; - default: return "unknown"; + case SERVER_MODEL_STATUS_UNLOADED: return "unloaded"; + case SERVER_MODEL_STATUS_DOWNLOADING: return "downloading"; + case SERVER_MODEL_STATUS_LOADING: return "loading"; + case SERVER_MODEL_STATUS_LOADED: return "loaded"; + case SERVER_MODEL_STATUS_SLEEPING: return "sleeping"; + default: return "unknown"; } } +using device_memory_map = std::map; + struct server_model_meta { common_preset preset; std::string name; @@ -62,6 +71,7 @@ struct server_model_meta { int port = 0; server_model_status status = SERVER_MODEL_STATUS_UNLOADED; int64_t last_used = 0; // for LRU unloading + device_memory_map dmm_req; // bytes required per device std::vector args; // args passed to the model instance, will be populated by render_args() int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED) int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown @@ -90,6 +100,7 @@ struct server_models { std::thread th; server_model_meta meta; FILE * stdin_file = nullptr; + uint64_t active_refs = 0; }; std::mutex mutex; @@ -107,14 +118,28 @@ struct server_models { std::vector base_env; common_preset base_preset; // base preset from llama-server CLI args + // available memory per device + device_memory_map dmm_available; + void update_meta(const std::string & name, const server_model_meta & meta); // unload least recently used models if the limit is reached - void unload_lru(); + void unload_lru(const device_memory_map & dmm_req); // not thread-safe, caller must hold mutex void add_model(server_model_meta && meta); + // return number of devices where the memory limit would be exceeded + // return 0 if the new model would fit on all devices + // not thread-safe, caller must hold mutex + int can_fit(const device_memory_map & dmm_req) const; + + // download model files, blocking call (caller must NOT hold mutex) + bool download_model(const std::string & name); + + // Internal helper for model loading + void _load(const std::string & name, const device_memory_map & dmm_req); + public: server_models(const common_params & params, int argc, char ** argv); @@ -150,6 +175,12 @@ struct server_models { // proxy an HTTP request to the model instance server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used); + // Increase instance ref counter + void inc_refs(const std::string & name); + + // Decrease instance ref counter + void dec_refs(const std::string & name); + // return true if the current process is a child server instance static bool is_child_server(); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 6566949edf1..4ff962b89fc 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -83,6 +83,11 @@ int main(int argc, char ** argv) { return 1; } + if (params.download_only) { + LOG_INF("%s: model downloaded successfully, exiting\n", __func__); + return 0; + } + // validate batch size for embeddings // embeddings require all tokens to be processed in a single ubatch // see https://github.com/ggml-org/llama.cpp/issues/12836 diff --git a/tools/server/tests/unit/test_router.py b/tools/server/tests/unit/test_router.py index 79e60db4083..d471ff88b55 100644 --- a/tools/server/tests/unit/test_router.py +++ b/tools/server/tests/unit/test_router.py @@ -1,4 +1,5 @@ import pytest +import threading from utils import * server: ServerProcess @@ -205,3 +206,126 @@ def test_router_api_key_required(): ) assert authed.status_code == 200 assert "error" not in authed.body + + +# --- Drain-aware eviction tests --- + + +def _make_completion(model_id: str, max_tokens: int = 16) -> dict: + """Send a non-streaming completion request. Returns {"content": ..., "error": ...}.""" + result = {"content": "", "error": None} + try: + res = server.make_request("POST", "/v1/chat/completions", data={ + "model": model_id, + "max_tokens": max_tokens, + "messages": [{"role": "user", "content": "hi"}], + }) + if res.status_code == 200: + choices = res.body.get("choices", []) + if choices: + result["content"] = choices[0].get("message", {}).get("content", "") + else: + result["error"] = f"status {res.status_code}: {res.body}" + except Exception as e: + result["error"] = str(e) + return result + + +def test_router_concurrent_no_thrashing(): + """Concurrent requests for different models should all succeed, not thrash.""" + global server + server = ServerPreset.router() + server.models_max = 1 + server.start() + + model_a = "ggml-org/tinygemma3-GGUF:Q8_0" + model_b = "ggml-org/test-model-stories260K:F32" + n_per_model = 3 + results = {} + + def send_request(model_id, idx): + results[(model_id, idx)] = _make_completion(model_id) + + threads = [] + for i in range(n_per_model): + threads.append(threading.Thread(target=send_request, args=(model_a, i))) + threads.append(threading.Thread(target=send_request, args=(model_b, i))) + + for t in threads: + t.start() + for t in threads: + t.join(timeout=300) + + failures = [f"{m} #{i}: {r['error']}" for (m, i), r in results.items() if r["error"] is not None] + assert len(failures) == 0, f"{len(failures)} request(s) failed:\n" + "\n".join(failures) + + +def test_router_concurrent_partial_capacity(): + """With models_max=2 and 3 models, concurrent requests should all succeed.""" + global server + server = ServerPreset.router() + server.models_max = 2 + server.start() + + models = [ + "ggml-org/tinygemma3-GGUF:Q8_0", + "ggml-org/test-model-stories260K:F32", + "ggml-org/test-model-stories260K-infill:F32", + ] + results = {} + + def send_request(model_id, idx): + results[(model_id, idx)] = _make_completion(model_id) + + threads = [] + for model in models: + for i in range(2): + threads.append(threading.Thread(target=send_request, args=(model, i))) + + for t in threads: + t.start() + for t in threads: + t.join(timeout=300) + + failures = [f"{m} #{i}: {r['error']}" for (m, i), r in results.items() if r["error"] is not None] + assert len(failures) == 0, f"{len(failures)} request(s) failed:\n" + "\n".join(failures) + + +def test_router_alternating_requests(): + """Repeated alternating requests between two models should all succeed.""" + global server + server = ServerPreset.router() + server.models_max = 1 + server.start() + + model_a = "ggml-org/tinygemma3-GGUF:Q8_0" + model_b = "ggml-org/test-model-stories260K:F32" + + for i in range(3): + result = _make_completion(model_a) + assert result["error"] is None, f"Round {i} model A failed: {result['error']}" + result = _make_completion(model_b) + assert result["error"] is None, f"Round {i} model B failed: {result['error']}" + + +def test_router_concurrent_same_model(): + """Concurrent requests for the same model should all succeed.""" + global server + server = ServerPreset.router() + server.models_max = 1 + server.start() + + model_id = "ggml-org/tinygemma3-GGUF:Q8_0" + results = {} + + def send_request(idx): + results[idx] = _make_completion(model_id) + + threads = [threading.Thread(target=send_request, args=(i,)) for i in range(6)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=300) + + failures = [f"#{i}: {r['error']}" for i, r in results.items() if r["error"] is not None] + assert len(failures) == 0, f"{len(failures)} request(s) failed:\n" + "\n".join(failures)