From e6c4319fc06f4cc15ff25c839954655da69bab87 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 26 Apr 2025 11:59:51 +0200 Subject: [PATCH 1/3] common : add common_remote_get_content --- common/arg.cpp | 85 ++++++++++++++++++++++++++++++-------------------- common/arg.h | 3 ++ 2 files changed, 55 insertions(+), 33 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 0657553e4e9..dc3be7c8d9b 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -527,52 +527,30 @@ static bool common_download_model( return true; } -/** - * Allow getting the HF file from the HF repo with tag (like ollama), for example: - * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 - * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M - * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s - * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo) - * - * Return pair of <repo, file> (with "repo" already having tag removed) - * - * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files. - */ -static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token) { - auto parts = string_split<std::string>(hf_repo_with_tag, ':'); - std::string tag = parts.size() > 1 ? 
parts.back() : "latest"; - std::string hf_repo = parts[0]; - if (string_split<std::string>(hf_repo, '/').size() != 2) { - throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n"); - } - - // fetch model info from Hugging Face Hub API +// get remote file content, returns <http_code, raw_response_body> +std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const std::vector<std::string> & headers) { curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); curl_slist_ptr http_headers; - std::string res_str; - - std::string model_endpoint = get_model_endpoint(); + std::vector<char> res_buffer; - std::string url = model_endpoint + "v2/" + hf_repo + "/manifests/" + tag; curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data); auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t { - static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb); + auto data_vec = static_cast<std::vector<char> *>(data); + data_vec->insert(data_vec->end(), (char *)ptr, (char *)ptr + size * nmemb); return size * nmemb; }; curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback)); - curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_buffer); #if defined(_WIN32) curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); #endif - if (!bearer_token.empty()) { - std::string auth_header = "Authorization: Bearer " + bearer_token; - http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); - } // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); - http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json"); + for (const auto & header : headers) { + http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str()); 
+ } curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); CURLcode res = curl_easy_perform(curl.get()); @@ -582,9 +560,46 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_ } long res_code; - std::string ggufFile = ""; - std::string mmprojFile = ""; curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code); + + return { res_code, res_buffer }; +} + +/** + * Allow getting the HF file from the HF repo with tag (like ollama), for example: + * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 + * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M + * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s + * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo) + * + * Return pair of <repo, file> (with "repo" already having tag removed) + * + * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files. + */ +static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token) { + auto parts = string_split<std::string>(hf_repo_with_tag, ':'); + std::string tag = parts.size() > 1 ? 
parts.back() : "latest"; + std::string hf_repo = parts[0]; + if (string_split<std::string>(hf_repo, '/').size() != 2) { + throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n"); + } + + std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag; + + // headers + std::vector<std::string> headers; + headers.push_back("Accept: application/json"); + if (!bearer_token.empty()) { + headers.push_back("Authorization: Bearer " + bearer_token); + } + + // make the request + auto res = common_remote_get_content(url, headers); + long res_code = res.first; + std::string res_str(res.second.data(), res.second.size()); + std::string ggufFile; + std::string mmprojFile; + if (res_code == 200) { // extract ggufFile.rfilename in json, using regex { @@ -640,6 +655,10 @@ static struct common_hf_file_res common_get_hf_file(const std::string &, const s return {}; } +std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const std::vector<std::string> & headers) { + throw std::runtime_error("error: built without CURL, cannot download model from the internet"); +} + #endif // LLAMA_USE_CURL // diff --git a/common/arg.h b/common/arg.h index 49ab8667b10..33a72b4c4d6 100644 --- a/common/arg.h +++ b/common/arg.h @@ -78,3 +78,6 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e // function to be used by test-arg-parser common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); + +// get remote file content, returns <http_code, raw_response_body> +std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const std::vector<std::string> & headers); From 4c09328687c400d17baced11302f51dd9cbf6a13 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 26 Apr 2025 12:13:26 +0200 Subject: [PATCH 2/3] support max size and timeout --- common/arg.cpp | 25 +++++++++++++++++-------- common/arg.h | 9 +++++++-- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 
dc3be7c8d9b..b054b1dd65f 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -527,8 +527,7 @@ static bool common_download_model( return true; } -// get remote file content, returns <http_code, raw_response_body> -std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const std::vector<std::string> & headers) { +std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params) { curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); curl_slist_ptr http_headers; std::vector<char> res_buffer; @@ -546,9 +545,14 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string & #if defined(_WIN32) curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); #endif - // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response + if (params.timeout > 0) { + curl_easy_setopt(curl.get(), CURLOPT_TIMEOUT, params.timeout); + } + if (params.max_size > 0) { + curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size); + } http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); - for (const auto & header : headers) { + for (const auto & header : params.headers) { http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str()); } curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); @@ -556,13 +560,14 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string & CURLcode res = curl_easy_perform(curl.get()); if (res != CURLE_OK) { - throw std::runtime_error("error: cannot make GET request to HF API"); + std::string error_msg = curl_easy_strerror(res); + throw std::runtime_error("error: cannot make GET request : " + error_msg); } long res_code; curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code); - return { res_code, res_buffer }; + return { res_code, std::move(res_buffer) }; } /** @@ -592,9 +597,13 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_ if (!bearer_token.empty()) { headers.push_back("Authorization: Bearer " + bearer_token); } + // Important: the 
User-Agent must be "llama-cpp" to get the "ggufFile" field in the response + // User-Agent header is already set in common_remote_get_content, no need to set it here // make the request - auto res = common_remote_get_content(url, headers); + common_remote_params params; + params.headers = headers; + auto res = common_remote_get_content(url, params); long res_code = res.first; std::string res_str(res.second.data(), res.second.size()); std::string ggufFile; @@ -655,7 +664,7 @@ static struct common_hf_file_res common_get_hf_file(const std::string &, const s return {}; } -std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const std::vector<std::string> & headers) { +std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params) { throw std::runtime_error("error: built without CURL, cannot download model from the internet"); } diff --git a/common/arg.h b/common/arg.h index 33a72b4c4d6..77997c4ef39 100644 --- a/common/arg.h +++ b/common/arg.h @@ -79,5 +79,10 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e // function to be used by test-arg-parser common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); -// get remote file content, returns <http_code, raw_response_body> -std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const std::vector<std::string> & headers); +struct common_remote_params { + std::vector<std::string> headers; + long timeout = 0; // CURLOPT_TIMEOUT, in seconds ; 0 means no timeout + long max_size = 0; // max size of the response ; unlimited if 0 ; max is 2GB +}; +// get remote file content, returns <http_code, raw_response_body> +std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params); From 617a1ffd61ebd2bff647e0547a6f0b37f9516a15 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 26 Apr 2025 15:11:12 +0200 Subject: [PATCH 3/3] add tests --- common/arg.cpp | 11 ++++++++- common/arg.h | 1 + tests/test-arg-parser.cpp | 47 
+++++++++++++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 1 deletion(-) diff --git a/common/arg.cpp b/common/arg.cpp index b054b1dd65f..de173159f4a 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -162,6 +162,10 @@ struct common_hf_file_res { #ifdef LLAMA_USE_CURL +bool common_has_curl() { + return true; +} + #ifdef __linux__ #include <linux/limits.h> #elif defined(_WIN32) @@ -534,6 +538,7 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string & curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); + curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data); auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t { auto data_vec = static_cast<std::vector<char> *>(data); @@ -561,7 +566,7 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string & if (res != CURLE_OK) { std::string error_msg = curl_easy_strerror(res); - throw std::runtime_error("error: cannot make GET request : " + error_msg); + throw std::runtime_error("error: cannot make GET request: " + error_msg); } long res_code; @@ -642,6 +647,10 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_ #else +bool common_has_curl() { + return false; +} + static bool common_download_file_single(const std::string &, const std::string &, const std::string &) { LOG_ERR("error: built without CURL, cannot download model from internet\n"); return false; diff --git a/common/arg.h b/common/arg.h index 77997c4ef39..70bea100fd4 100644 --- a/common/arg.h +++ b/common/arg.h @@ -78,6 +78,7 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e // function to be used by test-arg-parser common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); +bool common_has_curl(); struct common_remote_params { std::vector<std::string> headers; diff --git 
a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp index 537fc63a4c9..21dbd540422 100644 --- a/tests/test-arg-parser.cpp +++ b/tests/test-arg-parser.cpp @@ -126,6 +126,53 @@ int main(void) { assert(params.cpuparams.n_threads == 1010); #endif // _WIN32 + if (common_has_curl()) { + printf("test-arg-parser: test curl-related functions\n\n"); + const char * GOOD_URL = "https://raw.githubusercontent.com/ggml-org/llama.cpp/refs/heads/master/README.md"; + const char * BAD_URL = "https://www.google.com/404"; + const char * BIG_FILE = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v1.bin"; + + { + printf("test-arg-parser: test good URL\n\n"); + auto res = common_remote_get_content(GOOD_URL, {}); + assert(res.first == 200); + assert(res.second.size() > 0); + std::string str(res.second.data(), res.second.size()); + assert(str.find("llama.cpp") != std::string::npos); + } + + { + printf("test-arg-parser: test bad URL\n\n"); + auto res = common_remote_get_content(BAD_URL, {}); + assert(res.first == 404); + } + + { + printf("test-arg-parser: test max size error\n"); + common_remote_params params; + params.max_size = 1; + try { + common_remote_get_content(GOOD_URL, params); + assert(false && "it should throw an error"); + } catch (std::exception & e) { + printf(" expected error: %s\n\n", e.what()); + } + } + + { + printf("test-arg-parser: test timeout error\n"); + common_remote_params params; + params.timeout = 1; + try { + common_remote_get_content(BIG_FILE, params); + assert(false && "it should throw an error"); + } catch (std::exception & e) { + printf(" expected error: %s\n\n", e.what()); + } + } + } else { + printf("test-arg-parser: no curl, skipping curl-related functions\n"); + } printf("test-arg-parser: all tests OK\n\n"); }