From 60ae92bd5430640609e45afa62931fcaec08dae1 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Thu, 5 Sep 2024 19:26:21 +0200
Subject: [PATCH] handle env

---
 common/common.cpp         | 113 +++++++++++++++++---------------
 common/common.h           |  16 ++++++
 tests/test-arg-parser.cpp |  24 ++++++++
 3 files changed, 90 insertions(+), 63 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index ce9199c844254..49db551ae6339 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -77,41 +77,6 @@ using json = nlohmann::ordered_json;
 
-//
-// Environment variable utils
-//
-
-template <typename T>
-static typename std::enable_if<std::is_same<T, std::string>::value, void>::type
-get_env(std::string name, T & target) {
-    char * value = std::getenv(name.c_str());
-    target = value ? std::string(value) : target;
-}
-
-template <typename T>
-static typename std::enable_if<!std::is_same<T, bool>::value && std::is_integral<T>::value, void>::type
-get_env(std::string name, T & target) {
-    char * value = std::getenv(name.c_str());
-    target = value ? std::stoi(value) : target;
-}
-
-template <typename T>
-static typename std::enable_if<std::is_floating_point<T>::value, void>::type
-get_env(std::string name, T & target) {
-    char * value = std::getenv(name.c_str());
-    target = value ? std::stof(value) : target;
-}
-
-template <typename T>
-static typename std::enable_if<std::is_same<T, bool>::value, void>::type
-get_env(std::string name, T & target) {
-    char * value = std::getenv(name.c_str());
-    if (value) {
-        std::string val(value);
-        target = val == "1" || val == "true";
-    }
-}
-
 //
 // CPU utils
 //
@@ -390,6 +355,29 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params, std::vecto
         }
     }
 
+    // handle environment variables
+    for (auto & opt : options) {
+        std::string value;
+        if (opt.get_value_from_env(value)) {
+            try {
+                if (opt.handler_void && (value == "1" || value == "true")) {
+                    opt.handler_void();
+                }
+                if (opt.handler_int) {
+                    opt.handler_int(std::stoi(value));
+                }
+                if (opt.handler_string) {
+                    opt.handler_string(value);
+                    continue;
+                }
+            } catch (std::exception & e) {
+                throw std::invalid_argument(format(
+                    "error while handling environment variable \"%s\": %s\n\n", opt.env.c_str(), e.what()));
+            }
+        }
+    }
+
+    // handle command line arguments
     auto check_arg = [&](int i) {
         if (i+1 >= argc) {
             throw std::invalid_argument("expected value for argument");
         }
@@ -405,6 +393,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params, std::vecto
             throw std::invalid_argument(format("error: invalid argument: %s", arg.c_str()));
         }
         auto opt = *arg_to_options[arg];
+        if (opt.has_value_from_env()) {
+            fprintf(stderr, "warn: %s environment variable is set, but will be overwritten by command line argument %s\n", opt.env.c_str(), arg.c_str());
+        }
         try {
             if (opt.handler_void) {
                 opt.handler_void();
@@ -449,10 +440,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params, std::vecto
 
     gpt_params_handle_model_default(params);
 
-    if (params.hf_token.empty()) {
-        get_env("HF_TOKEN", params.hf_token);
-    }
-
     if (params.escape) {
         string_process_escapes(params.prompt);
         string_process_escapes(params.input_prefix);
@@ -762,7 +749,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
                 params.cpuparams.n_threads = std::thread::hardware_concurrency();
             }
         }
-    ));
+    ).set_env("LLAMA_ARG_THREADS"));
     add_opt(llama_arg(
         {"-tb", "--threads-batch"}, "N",
         "number of threads to use during batch and prompt processing (default: same as --threads)",
@@ -960,28 +947,28 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
         [&params](int value) {
             params.n_ctx = value;
         }
-    ));
+    ).set_env("LLAMA_ARG_CTX_SIZE"));
     add_opt(llama_arg(
         {"-n", "--predict"}, "N",
         format("number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict),
         [&params](int value) {
             params.n_predict = value;
         }
-    ));
+    ).set_env("LLAMA_ARG_N_PREDICT"));
     add_opt(llama_arg(
         {"-b", "--batch-size"}, "N",
         format("logical maximum batch size (default: %d)", params.n_batch),
         [&params](int value) {
             params.n_batch = value;
         }
-    ));
+    ).set_env("LLAMA_ARG_BATCH"));
     add_opt(llama_arg(
         {"-ub", "--ubatch-size"}, "N",
         format("physical maximum batch size (default: %d)", params.n_ubatch),
         [&params](int value) {
             params.n_ubatch = value;
         }
-    ));
+    ).set_env("LLAMA_ARG_UBATCH"));
     add_opt(llama_arg(
         {"--keep"}, "N",
         format("number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep),
@@ -1002,7 +989,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
         [&params]() {
             params.flash_attn = true;
         }
-    ));
+    ).set_env("LLAMA_ARG_FLASH_ATTN"));
     add_opt(llama_arg(
         {"-p", "--prompt"}, "PROMPT",
         "prompt to start generation with\n",
@@ -1599,7 +1586,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
         [&params](std::string value) {
             params.defrag_thold = std::stof(value);
         }
-    ));
+    ).set_env("LLAMA_ARG_DEFRAG_THOLD"));
     add_opt(llama_arg(
         {"-np", "--parallel"}, "N",
         format("number of parallel sequences to decode (default: %d)", params.n_parallel),
@@ -1620,14 +1607,14 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
         [&params]() {
             params.cont_batching = true;
         }
-    ));
+    ).set_env("LLAMA_ARG_CONT_BATCHING"));
     add_opt(llama_arg(
         {"-nocb", "--no-cont-batching"},
         "disable continuous batching",
         [&params]() {
             params.cont_batching = false;
         }
-    ));
+    ).set_env("LLAMA_ARG_NO_CONT_BATCHING"));
     add_opt(llama_arg(
         {"--mmproj"}, "FILE",
         "path to a multimodal projector file for LLaVA. see examples/llava/README.md",
@@ -1688,7 +1675,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
                 fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
             }
         }
-    ));
+    ).set_env("LLAMA_ARG_N_GPU_LAYERS"));
     add_opt(llama_arg(
         {"-ngld", "--gpu-layers-draft"}, "N",
         "number of layers to store in VRAM for the draft model",
@@ -1830,7 +1817,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
         [&params](std::string value) {
             params.model = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
+    ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}).set_env("LLAMA_ARG_MODEL"));
     add_opt(llama_arg(
         {"-md", "--model-draft"}, "FNAME",
         "draft model for speculative decoding (default: unused)",
@@ -1844,28 +1831,28 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
         [&params](std::string value) {
             params.model_url = value;
         }
-    ));
+    ).set_env("LLAMA_ARG_MODEL_URL"));
     add_opt(llama_arg(
         {"-hfr", "--hf-repo"}, "REPO",
         "Hugging Face model repository (default: unused)",
         [&params](std::string value) {
             params.hf_repo = value;
         }
-    ));
+    ).set_env("LLAMA_ARG_HF_REPO"));
     add_opt(llama_arg(
         {"-hff", "--hf-file"}, "FILE",
         "Hugging Face model file (default: unused)",
         [&params](std::string value) {
             params.hf_file = value;
         }
-    ));
+    ).set_env("LLAMA_ARG_HF_FILE"));
     add_opt(llama_arg(
         {"-hft", "--hf-token"}, "TOKEN",
         "Hugging Face access token (default: value from HF_TOKEN environment variable)",
         [&params](std::string value) {
             params.hf_token = value;
         }
-    ));
+    ).set_env("HF_TOKEN"));
     add_opt(llama_arg(
         {"--context-file"}, "FNAME",
         "file to load context from (repeat to specify multiple files)",
@@ -2012,14 +1999,14 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
         [&params](std::string value) {
             params.hostname = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_HOST"));
     add_opt(llama_arg(
         {"--port"}, "PORT",
         format("port to listen (default: %d)", params.port),
         [&params](int value) {
             params.port = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_PORT"));
     add_opt(llama_arg(
         {"--path"}, "PATH",
         format("path to serve static files from (default: %s)", params.public_path.c_str()),
@@ -2028,19 +2015,19 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(llama_arg(
-        {"--embedding(s)"},
+        {"--embedding", "--embeddings"},
         format("restrict to only support embedding use case; use only with dedicated embedding models (default: %s)", params.embedding ? "enabled" : "disabled"),
"enabled" : "disabled"), [¶ms]() { params.embedding = true; } - ).set_examples({LLAMA_EXAMPLE_SERVER})); + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS")); add_opt(llama_arg( {"--api-key"}, "KEY", "API key to use for authentication (default: none)", [¶ms](std::string value) { params.api_keys.push_back(value); } - ).set_examples({LLAMA_EXAMPLE_SERVER})); + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY")); add_opt(llama_arg( {"--api-key-file"}, "FNAME", "path to file containing API keys (default: none)", @@ -2086,7 +2073,7 @@ std::vector gpt_params_parser_init(gpt_params & params, llama_example [¶ms](int value) { params.n_threads_http = value; } - ).set_examples({LLAMA_EXAMPLE_SERVER})); + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_THREADS_HTTP")); add_opt(llama_arg( {"-spf", "--system-prompt-file"}, "FNAME", "set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications", @@ -2123,14 +2110,14 @@ std::vector gpt_params_parser_init(gpt_params & params, llama_example [¶ms]() { params.endpoint_metrics = true; } - ).set_examples({LLAMA_EXAMPLE_SERVER})); + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_METRICS")); add_opt(llama_arg( {"--no-slots"}, format("disables slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"), [¶ms]() { params.endpoint_slots = false; } - ).set_examples({LLAMA_EXAMPLE_SERVER})); + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_ENDPOINT_SLOTS")); add_opt(llama_arg( {"--slot-save-path"}, "PATH", "path to save slot kv cache (default: disabled)", @@ -2157,7 +2144,7 @@ std::vector gpt_params_parser_init(gpt_params & params, llama_example } params.chat_template = value; } - ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER})); + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE")); add_opt(llama_arg( {"-sps", "--slot-prompt-similarity"}, "SIMILARITY", format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity), diff --git a/common/common.h b/common/common.h index 05211bf972764..c6f476ec34586 100644 --- a/common/common.h +++ b/common/common.h @@ -316,6 +316,7 @@ struct llama_arg { llama_arg(std::vector args, std::string help, std::function handler) : args(args), help(help), handler_void(handler) {} // support 2 values for arg + // note: env variable is not yet support for 2 values llama_arg(std::vector args, std::string value_hint, std::string value_hint_2, std::string help, std::function handler) : args(args), value_hint(value_hint), value_hint_2(value_hint_2), help(help), handler_str_str(handler) {} llama_arg & set_examples(std::set examples) { @@ -324,6 +325,7 @@ struct llama_arg { } llama_arg & set_env(std::string env) { + help = help + "\n(env: " + env + ")"; this->env = std::move(env); return *this; } @@ -332,6 +334,20 @@ struct llama_arg { return examples.find(ex) != examples.end(); } + bool get_value_from_env(std::string & output) { + if (env.empty()) return false; + char * value = std::getenv(env.c_str()); + if (value) { + output = value; + return true; + } + return false; + } + + bool has_value_from_env() { + return std::getenv(env.c_str()); + } + std::string to_string(bool markdown); }; diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp index 8b95a59d39c86..ff1a626c39761 100644 --- a/tests/test-arg-parser.cpp +++ 
diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp
index 8b95a59d39c86..ff1a626c39761 100644
--- a/tests/test-arg-parser.cpp
+++ b/tests/test-arg-parser.cpp
@@ -63,5 +63,29 @@ int main(void) {
     assert(params.n_predict == 6789);
     assert(params.n_batch == 9090);
 
+    printf("test-arg-parser: test environment variables (valid + invalid usages)\n\n");
+
+    setenv("LLAMA_ARG_THREADS", "blah", true);
+    argv = {"binary_name"};
+    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+
+    setenv("LLAMA_ARG_MODEL", "blah.gguf", true);
+    setenv("LLAMA_ARG_THREADS", "1010", true);
+    argv = {"binary_name"};
+    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(params.model == "blah.gguf");
+    assert(params.cpuparams.n_threads == 1010);
+
+
+    printf("test-arg-parser: test environment variables being overwritten\n\n");
+
+    setenv("LLAMA_ARG_MODEL", "blah.gguf", true);
+    setenv("LLAMA_ARG_THREADS", "1010", true);
+    argv = {"binary_name", "-m", "overwritten.gguf"};
+    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(params.model == "overwritten.gguf");
+    assert(params.cpuparams.n_threads == 1010);
+
     printf("test-arg-parser: all tests OK\n\n");
 }
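Note on precedence: the parser above treats an environment variable as a fallback and lets a command line argument win, warning when both are set. A minimal standalone sketch of that rule, independent of the llama.cpp parser (the helper name resolve_option is hypothetical and used only for illustration):

    // env_precedence_sketch.cpp -- illustrative only, not part of the patch
    #include <cstdio>
    #include <cstdlib>
    #include <string>

    // CLI value wins; otherwise fall back to the environment variable, then to the default.
    static std::string resolve_option(const char * cli_value, const char * env_name, const std::string & def) {
        const char * env = std::getenv(env_name);
        if (cli_value) {
            if (env) {
                fprintf(stderr, "warn: %s is set, but will be overwritten by the command line value\n", env_name);
            }
            return cli_value;
        }
        return env ? std::string(env) : def;
    }

    int main() {
        // With LLAMA_ARG_MODEL=blah.gguf in the environment and no -m flag, this prints "blah.gguf";
        // the second call simulates "-m overwritten.gguf" taking precedence, as the test above asserts.
        printf("%s\n", resolve_option(nullptr, "LLAMA_ARG_MODEL", "default.gguf").c_str());
        printf("%s\n", resolve_option("overwritten.gguf", "LLAMA_ARG_MODEL", "default.gguf").c_str());
        return 0;
    }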