Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions common/arg.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,3 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e

// function to be used by test-arg-parser
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);

struct common_remote_params {
std::vector<std::string> headers;
long timeout = 0; // CURLOPT_TIMEOUT, in seconds ; 0 means no timeout
long max_size = 0; // max size of the response ; unlimited if 0 ; max is 2GB
};
// get remote file content, returns <http_code, raw_response_body>
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params);
200 changes: 139 additions & 61 deletions common/download.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,8 @@ static bool common_download_head(CURL * curl,
// download one single file from remote URL to local path
static bool common_download_file_single_online(const std::string & url,
const std::string & path,
const std::string & bearer_token) {
const std::string & bearer_token,
const common_header_list & custom_headers) {
static const int max_attempts = 3;
static const int retry_delay_seconds = 2;
for (int i = 0; i < max_attempts; ++i) {
Expand All @@ -325,6 +326,11 @@ static bool common_download_file_single_online(const std::string & url,
common_load_model_from_url_headers headers;
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
curl_slist_ptr http_headers;

for (const auto & h : custom_headers) {
std::string s = h.first + ": " + h.second;
http_headers.ptr = curl_slist_append(http_headers.ptr, s.c_str());
}
const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token);
if (!was_perform_successful) {
head_request_ok = false;
Expand Down Expand Up @@ -449,8 +455,10 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size);
}
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");

for (const auto & header : params.headers) {
http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str());
std::string header_ = header.first + ": " + header.second;
http_headers.ptr = curl_slist_append(http_headers.ptr, header_.c_str());
}
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);

Expand Down Expand Up @@ -562,19 +570,64 @@ static bool common_pull_file(httplib::Client & cli,
return true;
}

static void common_set_clean_host_header(httplib::Headers & headers, const std::string & host) {
if (headers.count("Host")) {
headers.erase("Host");
}

std::string clean_host = host;
size_t pos = clean_host.find(':');
if (pos != std::string::npos) {
clean_host = clean_host.substr(0, pos);
}

headers.emplace("Host", clean_host);
}

static void common_resolve_redirects(std::string & url, httplib::Headers & headers) {
for (int r = 0; r < 5; ++r) {
auto [cli, parts] = common_http_client(url);
cli.set_follow_location(false);
common_set_clean_host_header(headers, parts.host);

httplib::Headers probe_headers = headers;
probe_headers.emplace("Range", "bytes=0-0");

auto head = cli.Get(parts.path, probe_headers);

if (head && (head->status >= 300 && head->status < 400) && head->has_header("Location")) {
url = head->get_header_value("Location");
if (headers.count("Authorization")) {
headers.erase("Authorization");
}
continue;
}
break;
}
auto parts = common_http_parse_url(url);
common_set_clean_host_header(headers, parts.host);
}

// download one single file from remote URL to local path
static bool common_download_file_single_online(const std::string & url,
const std::string & path,
const std::string & bearer_token) {
const std::string & bearer_token,
const common_header_list & custom_headers) {
static const int max_attempts = 3;
static const int retry_delay_seconds = 2;

auto [cli, parts] = common_http_client(url);

httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}};
if (!bearer_token.empty()) {
default_headers.insert({"Authorization", "Bearer " + bearer_token});
}
for (const auto & h : custom_headers) {
default_headers.emplace(h.first, h.second);
}

std::string real_url = url;
common_resolve_redirects(real_url, default_headers);

auto [cli, parts] = common_http_client(real_url);
cli.set_default_headers(default_headers);

const bool file_exists = std::filesystem::exists(path);
Expand All @@ -589,7 +642,9 @@ static bool common_download_file_single_online(const std::string & url,
for (int i = 0; i < max_attempts; ++i) {
auto head = cli.Head(parts.path);
bool head_ok = head && head->status >= 200 && head->status < 300;
if (!head_ok) {
bool head_403 = head && head->status == 403;

if (!head_ok && !head_403) {
LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
if (file_exists) {
LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
Expand All @@ -598,22 +653,26 @@ static bool common_download_file_single_online(const std::string & url,
}

std::string etag;
if (head_ok && head->has_header("ETag")) {
etag = head->get_header_value("ETag");
}

size_t total_size = 0;
if (head_ok && head->has_header("Content-Length")) {
try {
total_size = std::stoull(head->get_header_value("Content-Length"));
} catch (const std::exception& e) {
LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what());
}
}

bool supports_ranges = false;
if (head_ok && head->has_header("Accept-Ranges")) {
supports_ranges = head->get_header_value("Accept-Ranges") != "none";

if (head_ok) {
if (head->has_header("ETag")) {
etag = head->get_header_value("ETag");
}
if (head->has_header("Content-Length")) {
try {
total_size = std::stoull(head->get_header_value("Content-Length"));
} catch (const std::exception& e) {
LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what());
}
}
if (head->has_header("Accept-Ranges")) {
supports_ranges = head->get_header_value("Accept-Ranges") != "none";
}
} else if (head_403) {
LOG_INF("%s: 403 on HEAD, assuming GET/Resume is allowed\n", __func__);
supports_ranges = true;
}

bool should_download_from_scratch = false;
Expand Down Expand Up @@ -680,13 +739,9 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
auto [cli, parts] = common_http_client(url);

httplib::Headers headers = {{"User-Agent", "llama-cpp"}};

for (const auto & header : params.headers) {
size_t pos = header.find(':');
if (pos != std::string::npos) {
headers.emplace(header.substr(0, pos), header.substr(pos + 1));
} else {
headers.emplace(header, "");
}
headers.emplace(header.first, header.second);
}

if (params.timeout > 0) {
Expand Down Expand Up @@ -718,9 +773,10 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
static bool common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
bool offline) {
bool offline,
const common_header_list & headers) {
if (!offline) {
return common_download_file_single_online(url, path, bearer_token);
return common_download_file_single_online(url, path, bearer_token, headers);
}

if (!std::filesystem::exists(path)) {
Expand All @@ -734,13 +790,24 @@ static bool common_download_file_single(const std::string & url,

// download multiple files from remote URLs to local paths
// the input is a vector of pairs <url, path>
static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls, const std::string & bearer_token, bool offline) {
static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls,
const std::string & bearer_token,
bool offline,
const common_header_list & headers) {
// Prepare download in parallel
std::vector<std::future<bool>> futures_download;
futures_download.reserve(urls.size());

for (auto const & item : urls) {
futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair<std::string, std::string> & it) -> bool {
return common_download_file_single(it.first, it.second, bearer_token, offline);
}, item));
futures_download.push_back(
std::async(
std::launch::async,
[&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
return common_download_file_single(it.first, it.second, bearer_token, offline, headers);
},
item
)
);
}

// Wait for all downloads to complete
Expand All @@ -753,17 +820,17 @@ static bool common_download_file_multiple(const std::vector<std::pair<std::strin
return true;
}

bool common_download_model(
const common_params_model & model,
const std::string & bearer_token,
bool offline) {
bool common_download_model(const common_params_model & model,
const std::string & bearer_token,
bool offline,
const common_header_list & headers) {
// Basic validation of the model.url
if (model.url.empty()) {
LOG_ERR("%s: invalid model url\n", __func__);
return false;
}

if (!common_download_file_single(model.url, model.path, bearer_token, offline)) {
if (!common_download_file_single(model.url, model.path, bearer_token, offline, headers)) {
return false;
}

Expand Down Expand Up @@ -822,13 +889,16 @@ bool common_download_model(
}

// Download in parallel
common_download_file_multiple(urls, bearer_token, offline);
common_download_file_multiple(urls, bearer_token, offline, headers);
}

return true;
}

common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token, bool offline) {
common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag,
const std::string & bearer_token,
bool offline,
const common_header_list & custom_headers) {
auto parts = string_split<std::string>(hf_repo_with_tag, ':');
std::string tag = parts.size() > 1 ? parts.back() : "latest";
std::string hf_repo = parts[0];
Expand All @@ -839,10 +909,10 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons
std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;

// headers
std::vector<std::string> headers;
headers.push_back("Accept: application/json");
common_header_list headers = custom_headers;
headers.push_back({"Accept", "application/json"});
if (!bearer_token.empty()) {
headers.push_back("Authorization: Bearer " + bearer_token);
headers.push_back({"Authorization", "Bearer " + bearer_token});
}
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
// User-Agent header is already set in common_remote_get_content, no need to set it here
Expand Down Expand Up @@ -913,8 +983,14 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons
// Docker registry functions
//

static std::string common_docker_get_token(const std::string & repo) {
std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
static std::string common_docker_get_token(const std::string & repo,
const common_oci_params & oci_params) {
if (oci_params.auth_url.empty()) {
return "";
}
std::string url = oci_params.auth_url
+ "?service=" + oci_params.auth_service
+ "&scope=repository:" + repo + ":pull";

common_remote_params params;
auto res = common_remote_get_content(url, params);
Expand All @@ -933,7 +1009,7 @@ static std::string common_docker_get_token(const std::string & repo) {
return response["token"].get<std::string>();
}

std::string common_docker_resolve_model(const std::string & docker) {
std::string common_docker_resolve_model(const std::string & docker, const common_oci_params & params) {
// Parse ai/smollm2:135M-Q4_0
size_t colon_pos = docker.find(':');
std::string repo, tag;
Expand Down Expand Up @@ -970,16 +1046,20 @@ std::string common_docker_resolve_model(const std::string & docker) {
return normalized;
};

std::string token = common_docker_get_token(repo); // Get authentication token
std::string token = common_docker_get_token(repo, params); // Get authentication token

// Get manifest
// TODO: cache the manifest response so that it appears in the model list
const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
const std::string url_prefix = params.registry_url + "/v2/" + repo;
std::string manifest_url = url_prefix + "/manifests/" + tag;
common_remote_params manifest_params;
manifest_params.headers.push_back("Authorization: Bearer " + token);
manifest_params.headers.push_back(
"Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");

if (!token.empty()) {
manifest_params.headers.push_back({"Authorization", "Bearer " + token});
}
manifest_params.headers.push_back({"Accept",
"application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"
});
auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
if (manifest_res.first != 200) {
throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
Expand All @@ -990,17 +1070,15 @@ std::string common_docker_resolve_model(const std::string & docker) {
std::string gguf_digest; // Find the GGUF layer
if (manifest.contains("layers")) {
for (const auto & layer : manifest["layers"]) {
if (layer.contains("mediaType")) {
std::string media_type = layer["mediaType"].get<std::string>();
if (media_type == "application/vnd.docker.ai.gguf.v3" ||
media_type.find("gguf") != std::string::npos) {
gguf_digest = layer["digest"].get<std::string>();
break;
}
if (!layer.contains("mediaType") || !layer.contains("digest")) {
continue;
}
if (layer["mediaType"].get<std::string>() == params.media_type) {
gguf_digest = layer["digest"].get<std::string>();
break;
}
}
}

if (gguf_digest.empty()) {
throw std::runtime_error("No GGUF layer found in Docker manifest");
}
Expand All @@ -1016,7 +1094,7 @@ std::string common_docker_resolve_model(const std::string & docker) {
std::string local_path = fs_get_cache_file(model_filename);

const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
if (!common_download_file_single(blob_url, local_path, token, false)) {
if (!common_download_file_single(blob_url, local_path, token, false, {})) {
throw std::runtime_error("Failed to download Docker Model");
}

Expand All @@ -1030,11 +1108,11 @@ std::string common_docker_resolve_model(const std::string & docker) {

#else

common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool) {
common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) {
throw std::runtime_error("download functionality is not enabled in this build");
}

bool common_download_model(const common_params_model &, const std::string &, bool) {
bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) {
throw std::runtime_error("download functionality is not enabled in this build");
}

Expand Down
Loading
Loading