Skip to content
21 changes: 16 additions & 5 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
common_download_opts opts;
opts.bearer_token = params.hf_token;
opts.offline = params.offline;
opts.hf_prune_old_files = params.hf_prune_old_files;
const int status = common_download_file_single(preset_url, preset_path, opts);
const bool has_preset = status >= 200 && status < 400;

Expand Down Expand Up @@ -332,7 +333,8 @@ struct handle_model_result {

static handle_model_result common_params_handle_model(struct common_params_model & model,
const std::string & bearer_token,
bool offline) {
bool offline,
bool hf_prune_old_files) {
handle_model_result result;

if (!model.docker_repo.empty()) {
Expand All @@ -347,6 +349,7 @@ static handle_model_result common_params_handle_model(struct common_params_model
common_download_opts opts;
opts.bearer_token = bearer_token;
opts.offline = offline;
opts.hf_prune_old_files = hf_prune_old_files;
auto download_result = common_download_model(model, opts, true);

if (download_result.model_path.empty()) {
Expand All @@ -371,6 +374,7 @@ static handle_model_result common_params_handle_model(struct common_params_model
common_download_opts opts;
opts.bearer_token = bearer_token;
opts.offline = offline;
opts.hf_prune_old_files = hf_prune_old_files;
auto download_result = common_download_model(model, opts);
if (download_result.model_path.empty()) {
LOG_ERR("error: failed to download model from %s\n", model.url.c_str());
Expand Down Expand Up @@ -577,7 +581,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context

// handle model and download
if (!skip_model_download) {
auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
auto res = common_params_handle_model(params.model, params.hf_token, params.offline, params.hf_prune_old_files);
if (params.no_mmproj) {
params.mmproj = {};
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
Expand All @@ -587,12 +591,12 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
// only download mmproj if the current example is using it
for (const auto & ex : mmproj_examples) {
if (ctx_arg.ex == ex) {
common_params_handle_model(params.mmproj, params.hf_token, params.offline);
common_params_handle_model(params.mmproj, params.hf_token, params.offline, params.hf_prune_old_files);
break;
}
}
common_params_handle_model(params.speculative.mparams_dft, params.hf_token, params.offline);
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
common_params_handle_model(params.speculative.mparams_dft, params.hf_token, params.offline, params.hf_prune_old_files);
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline, params.hf_prune_old_files);
}

// model is required (except for server)
Expand Down Expand Up @@ -2649,6 +2653,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.hf_token = value;
}
).set_env("HF_TOKEN"));
add_opt(common_arg(
{"-hfp", "--hf-prune-old-files"},
string_format("Keep only latest version of model files, delete old ones (default: %s)", params.hf_prune_old_files ? "true" : "false"),
[](common_params & params) {
params.hf_prune_old_files = true;
}
).set_env("LLAMA_ARG_HF_PRUNE_OLD_FILES"));
add_opt(common_arg(
{"--context-file"}, "FNAME",
"file to load context from (use comma-separated values to specify multiple files)",
Expand Down
7 changes: 4 additions & 3 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -453,9 +453,10 @@ struct common_params {

struct common_params_model model;

std::set<std::string> model_alias; // model aliases // NOLINT
std::set<std::string> model_tags; // model tags (informational, not used for routing) // NOLINT
std::string hf_token = ""; // HF token // NOLINT
std::set<std::string> model_alias; // model aliases // NOLINT
std::set<std::string> model_tags; // model tags (informational, not used for routing) // NOLINT
std::string hf_token = ""; // HF token // NOLINT
bool hf_prune_old_files = false; // whether to keep only latest version of model files // NOLINT
std::string prompt = ""; // NOLINT
std::string system_prompt = ""; // NOLINT
std::string prompt_file = ""; // store the external prompt file name // NOLINT
Expand Down
5 changes: 5 additions & 0 deletions common/download.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,11 @@ common_download_model_result common_download_model(const common_params_model &
if (!hf.mmproj.path.empty()) {
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
}

if (opts.hf_prune_old_files) {
auto hf_repo_with_tag = common_download_split_repo_tag(model.hf_repo);
hf_cache::prune_old_files(hf_repo_with_tag.first, hf.model_files, hf.mmproj);
}
} else {
result.model_path = model.path;
}
Expand Down
1 change: 1 addition & 0 deletions common/download.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ struct common_download_opts {
common_header_list headers;
bool offline = false;
common_download_callback * callback = nullptr;
bool hf_prune_old_files = false;
};

// Result of common_download_model
Expand Down
129 changes: 129 additions & 0 deletions common/hf-cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,57 @@ hf_files get_cached_files(const std::string & repo_id) {
return files;
}

hf_files get_all_snapshot_files(const std::string & repo_id) {
fs::path cache_dir = get_cache_directory();
if (!fs::exists(cache_dir)) {
return {};
}

if (!repo_id.empty() && !is_valid_repo_id(repo_id)) {
LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str());
return {};
}

hf_files files;

for (const auto & repo : fs::directory_iterator(cache_dir)) {
if (!repo.is_directory()) {
continue;
}
fs::path snapshots_path = repo.path() / "snapshots";

if (!fs::exists(snapshots_path)) {
continue;
}
std::string _repo_id = folder_name_to_repo(repo.path().filename().string());

if (!is_valid_repo_id(_repo_id)) {
continue;
}
if (!repo_id.empty() && _repo_id != repo_id) {
continue;
}

for (const auto & entry : fs::recursive_directory_iterator(snapshots_path)) {
if (!entry.is_regular_file() && !fs::is_directory(entry.path())) {
continue;
}
fs::path path = entry.path();

if (!path.empty()) {
hf_file file;
file.repo_id = _repo_id;
file.path = path.generic_string();
file.local_path = entry.path().string();
file.final_path = file.local_path;
files.push_back(std::move(file));
}
}
}

return files;
}

std::string finalize_file(const hf_file & file) {
static std::atomic<bool> symlinks_disabled{false};

Expand Down Expand Up @@ -502,6 +553,84 @@ std::string finalize_file(const hf_file & file) {
return file.final_path;
}

void prune_old_files(const std::string & hf_repo, const hf_cache::hf_files & current_model_files, const hf_cache::hf_file & current_mmproj) {
std::vector<std::string> filenames_to_delete;
std::vector<std::string> files_to_keep;

const auto get_symlink_target = [&](const std::string & file) {
std::error_code ec;

const auto & parent = fs::path(file).parent_path();
const auto & target_relative = fs::read_symlink(file, ec);
if (ec) {
LOG_DBG("%s: failed to read symlink %s: %s\n", __func__, file.c_str(), ec.message().c_str());
return std::string();
}
const auto & target_unresolved = parent / target_relative;
const auto & target = fs::weakly_canonical(target_unresolved, ec);
if (ec) {
LOG_DBG("%s: failed to resolve symlink target %s: %s\n", __func__, file.c_str(), ec.message().c_str());
return std::string();
}
return std::string(target);
};

for (const auto & file : current_model_files) {
files_to_keep.push_back(file.local_path);
filenames_to_delete.push_back(fs::path(file.local_path).filename());
const auto & target = get_symlink_target(file.local_path);
if (!target.empty()) {
files_to_keep.push_back(target);
}
}

if (!current_mmproj.local_path.empty()) {
files_to_keep.push_back(current_mmproj.local_path);
filenames_to_delete.push_back(fs::path(current_mmproj.local_path).filename());
const auto & target = get_symlink_target(current_mmproj.local_path);
if (!target.empty()) {
files_to_keep.push_back(target);
}
}

const auto cached_snapshot_files = hf_cache::get_all_snapshot_files(hf_repo);
for (int i = cached_snapshot_files.size() - 1; i >= 0; --i) {
const auto & file_path = cached_snapshot_files[i].local_path;
if (std::find(files_to_keep.begin(), files_to_keep.end(), file_path) != files_to_keep.end()) {
continue;
}
std::error_code ec;
for (const auto & filename : filenames_to_delete) {
if (string_ends_with(file_path, filename) && fs::is_symlink(file_path)) {
const auto & commit = fs::path(file_path).parent_path();
const auto & blob_file = get_symlink_target(file_path);

if (!fs::remove(file_path.c_str(), ec)) {
LOG_ERR("%s: error deleting old symlink file from hf cache %s: %s\n", __func__, file_path.c_str(), ec.message().c_str());
return;
}

if (fs::is_empty(commit)) {
if (!fs::remove(commit.c_str(), ec)) {
LOG_ERR("%s: error deleting old commit directory from hf cache %s: %s\n", __func__, commit.c_str(), ec.message().c_str());
return;
}
}

if (!blob_file.empty() && std::find(files_to_keep.begin(), files_to_keep.end(), blob_file) == files_to_keep.end()) {
LOG_INF("deleting old blob file from hf cache: %s -> %s\n", file_path.c_str(), blob_file.c_str());
if (fs::exists(blob_file)) {
if (!fs::remove(blob_file.c_str(), ec)) {
LOG_ERR("%s: error deleting old hf blob file %s: %s\n", __func__, file_path.c_str(), ec.message().c_str());
return;
}
}
}
}
}
}
}

// delete everything after this line, one day

// copied from download.cpp without the tag part
Expand Down
3 changes: 3 additions & 0 deletions common/hf-cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@ hf_files get_repo_files(
);

hf_files get_cached_files(const std::string & repo_id = {});
hf_files get_all_snapshot_files(const std::string & repo_id = {});

// Create snapshot path (link or move/copy) and return it
std::string finalize_file(const hf_file & file);

void prune_old_files(const std::string & hf_repo, const hf_files & current_model_files, const hf_file & current_mmproj);

// TODO: Remove later
void migrate_old_cache_to_hf_cache(const std::string & token, bool offline = false);

Expand Down