handle env
ngxson committed Sep 5, 2024
1 parent 753782a commit 60ae92b
Showing 3 changed files with 90 additions and 63 deletions.
113 changes: 50 additions & 63 deletions common/common.cpp
@@ -77,41 +77,6 @@

using json = nlohmann::ordered_json;

//
// Environment variable utils
//

template<typename T>
static typename std::enable_if<std::is_same<T, std::string>::value, void>::type
get_env(std::string name, T & target) {
char * value = std::getenv(name.c_str());
target = value ? std::string(value) : target;
}

template<typename T>
static typename std::enable_if<!std::is_same<T, bool>::value && std::is_integral<T>::value, void>::type
get_env(std::string name, T & target) {
char * value = std::getenv(name.c_str());
target = value ? std::stoi(value) : target;
}

template<typename T>
static typename std::enable_if<std::is_floating_point<T>::value, void>::type
get_env(std::string name, T & target) {
char * value = std::getenv(name.c_str());
target = value ? std::stof(value) : target;
}

template<typename T>
static typename std::enable_if<std::is_same<T, bool>::value, void>::type
get_env(std::string name, T & target) {
char * value = std::getenv(name.c_str());
if (value) {
std::string val(value);
target = val == "1" || val == "true";
}
}

//
// CPU utils
//
@@ -390,6 +355,29 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params, std::vecto
}
}

// handle environment variables
for (auto & opt : options) {
std::string value;
if (opt.get_value_from_env(value)) {
try {
if (opt.handler_void && (value == "1" || value == "true")) {
opt.handler_void();
}
if (opt.handler_int) {
opt.handler_int(std::stoi(value));
}
if (opt.handler_string) {
opt.handler_string(value);
continue;
}
} catch (std::exception & e) {
throw std::invalid_argument(format(
"error while handling environment variable \"%s\": %s\n\n", opt.env.c_str(), e.what()));
}
}
}

// handle command line arguments
auto check_arg = [&](int i) {
if (i+1 >= argc) {
throw std::invalid_argument("expected value for argument");
@@ -405,6 +393,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params, std::vecto
throw std::invalid_argument(format("error: invalid argument: %s", arg.c_str()));
}
auto opt = *arg_to_options[arg];
if (opt.has_value_from_env()) {
fprintf(stderr, "warn: %s environment variable is set, but will be overwritten by command line argument %s\n", opt.env.c_str(), arg.c_str());
}
try {
if (opt.handler_void) {
opt.handler_void();
@@ -449,10 +440,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params, std::vecto

gpt_params_handle_model_default(params);

if (params.hf_token.empty()) {
get_env("HF_TOKEN", params.hf_token);
}

if (params.escape) {
string_process_escapes(params.prompt);
string_process_escapes(params.input_prefix);
@@ -762,7 +749,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
params.cpuparams.n_threads = std::thread::hardware_concurrency();
}
}
));
).set_env("LLAMA_ARG_THREADS"));
add_opt(llama_arg(
{"-tb", "--threads-batch"}, "N",
"number of threads to use during batch and prompt processing (default: same as --threads)",
@@ -960,28 +947,28 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[&params](int value) {
params.n_ctx = value;
}
));
).set_env("LLAMA_ARG_CTX_SIZE"));
add_opt(llama_arg(
{"-n", "--predict"}, "N",
format("number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict),
[&params](int value) {
params.n_predict = value;
}
));
).set_env("LLAMA_ARG_N_PREDICT"));
add_opt(llama_arg(
{"-b", "--batch-size"}, "N",
format("logical maximum batch size (default: %d)", params.n_batch),
[&params](int value) {
params.n_batch = value;
}
));
).set_env("LLAMA_ARG_BATCH"));
add_opt(llama_arg(
{"-ub", "--ubatch-size"}, "N",
format("physical maximum batch size (default: %d)", params.n_ubatch),
[&params](int value) {
params.n_ubatch = value;
}
));
).set_env("LLAMA_ARG_UBATCH"));
add_opt(llama_arg(
{"--keep"}, "N",
format("number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep),
@@ -1002,7 +989,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[&params]() {
params.flash_attn = true;
}
));
).set_env("LLAMA_ARG_FLASH_ATTN"));
add_opt(llama_arg(
{"-p", "--prompt"}, "PROMPT",
"prompt to start generation with\n",
@@ -1599,7 +1586,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[&params](std::string value) {
params.defrag_thold = std::stof(value);
}
));
).set_env("LLAMA_ARG_DEFRAG_THOLD"));
add_opt(llama_arg(
{"-np", "--parallel"}, "N",
format("number of parallel sequences to decode (default: %d)", params.n_parallel),
@@ -1620,14 +1607,14 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[&params]() {
params.cont_batching = true;
}
));
).set_env("LLAMA_ARG_CONT_BATCHING"));
add_opt(llama_arg(
{"-nocb", "--no-cont-batching"},
"disable continuous batching",
[&params]() {
params.cont_batching = false;
}
));
).set_env("LLAMA_ARG_NO_CONT_BATCHING"));
add_opt(llama_arg(
{"--mmproj"}, "FILE",
"path to a multimodal projector file for LLaVA. see examples/llava/README.md",
@@ -1688,7 +1675,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
}
}
));
).set_env("LLAMA_ARG_N_GPU_LAYERS"));
add_opt(llama_arg(
{"-ngld", "--gpu-layers-draft"}, "N",
"number of layers to store in VRAM for the draft model",
@@ -1830,7 +1817,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[&params](std::string value) {
params.model = value;
}
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}).set_env("LLAMA_ARG_MODEL"));
add_opt(llama_arg(
{"-md", "--model-draft"}, "FNAME",
"draft model for speculative decoding (default: unused)",
@@ -1844,28 +1831,28 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[&params](std::string value) {
params.model_url = value;
}
));
).set_env("LLAMA_ARG_MODEL_URL"));
add_opt(llama_arg(
{"-hfr", "--hf-repo"}, "REPO",
"Hugging Face model repository (default: unused)",
[&params](std::string value) {
params.hf_repo = value;
}
));
).set_env("LLAMA_ARG_HF_REPO"));
add_opt(llama_arg(
{"-hff", "--hf-file"}, "FILE",
"Hugging Face model file (default: unused)",
[&params](std::string value) {
params.hf_file = value;
}
));
).set_env("LLAMA_ARG_HF_FILE"));
add_opt(llama_arg(
{"-hft", "--hf-token"}, "TOKEN",
"Hugging Face access token (default: value from HF_TOKEN environment variable)",
[&params](std::string value) {
params.hf_token = value;
}
));
).set_env("HF_TOKEN"));
add_opt(llama_arg(
{"--context-file"}, "FNAME",
"file to load context from (repeat to specify multiple files)",
@@ -2012,14 +1999,14 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[&params](std::string value) {
params.hostname = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_HOST"));
add_opt(llama_arg(
{"--port"}, "PORT",
format("port to listen (default: %d)", params.port),
[&params](int value) {
params.port = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_PORT"));
add_opt(llama_arg(
{"--path"}, "PATH",
format("path to serve static files from (default: %s)", params.public_path.c_str()),
@@ -2028,19 +2015,19 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(llama_arg(
{"--embedding(s)"},
{"--embedding", "--embeddings"},
format("restrict to only support embedding use case; use only with dedicated embedding models (default: %s)", params.embedding ? "enabled" : "disabled"),
[&params]() {
params.embedding = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
add_opt(llama_arg(
{"--api-key"}, "KEY",
"API key to use for authentication (default: none)",
[&params](std::string value) {
params.api_keys.push_back(value);
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY"));
add_opt(llama_arg(
{"--api-key-file"}, "FNAME",
"path to file containing API keys (default: none)",
@@ -2086,7 +2073,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[&params](int value) {
params.n_threads_http = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_THREADS_HTTP"));
add_opt(llama_arg(
{"-spf", "--system-prompt-file"}, "FNAME",
"set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications",
@@ -2123,14 +2110,14 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[&params]() {
params.endpoint_metrics = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_METRICS"));
add_opt(llama_arg(
{"--no-slots"},
format("disables slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"),
[&params]() {
params.endpoint_slots = false;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_ENDPOINT_SLOTS"));
add_opt(llama_arg(
{"--slot-save-path"}, "PATH",
"path to save slot kv cache (default: disabled)",
@@ -2157,7 +2144,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
}
params.chat_template = value;
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
add_opt(llama_arg(
{"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),
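
Taken together, common.cpp now resolves option values in two passes: registered environment variables are applied first, then command-line arguments are parsed and override the environment with a warning. Below is a minimal, self-contained sketch of that precedence scheme; the Option struct, the option table, and the main() harness are hypothetical stand-ins for illustration, not the actual llama_arg machinery:

#include <cstdio>
#include <cstdlib>
#include <functional>
#include <string>
#include <vector>

// hypothetical stand-in for llama_arg: one flag, one env var, one string handler
struct Option {
    std::string flag;
    std::string env;
    std::function<void(const std::string &)> handler;
};

int main(int argc, char ** argv) {
    int n_threads = 4; // default value

    std::vector<Option> options = {
        {"--threads", "LLAMA_ARG_THREADS",
         [&](const std::string & v) { n_threads = std::stoi(v); }},
    };

    // pass 1: environment variables
    for (auto & opt : options) {
        if (const char * v = std::getenv(opt.env.c_str())) {
            opt.handler(v);
        }
    }

    // pass 2: command-line arguments win, warning when both are set
    for (int i = 1; i + 1 < argc; i += 2) {
        for (auto & opt : options) {
            if (opt.flag == argv[i]) {
                if (std::getenv(opt.env.c_str())) {
                    fprintf(stderr, "warn: %s is set, but will be overwritten by %s\n",
                            opt.env.c_str(), opt.flag.c_str());
                }
                opt.handler(argv[i + 1]);
            }
        }
    }

    printf("n_threads = %d\n", n_threads);
    return 0;
}

Run with LLAMA_ARG_THREADS=8 in the environment and --threads 16 on the command line, and the sketch prints 16 after emitting the warning, mirroring the behavior of gpt_params_parse_ex above.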
16 changes: 16 additions & 0 deletions common/common.h
@@ -316,6 +316,7 @@ struct llama_arg {
llama_arg(std::vector<std::string> args, std::string help, std::function<void(void)> handler) : args(args), help(help), handler_void(handler) {}

// support 2 values for arg
// note: env variables are not yet supported for args with 2 values
llama_arg(std::vector<std::string> args, std::string value_hint, std::string value_hint_2, std::string help, std::function<void(std::string, std::string)> handler) : args(args), value_hint(value_hint), value_hint_2(value_hint_2), help(help), handler_str_str(handler) {}

llama_arg & set_examples(std::set<enum llama_example> examples) {
@@ -324,6 +325,7 @@ struct llama_arg {
}

llama_arg & set_env(std::string env) {
help = help + "\n(env: " + env + ")";
this->env = std::move(env);
return *this;
}
@@ -332,6 +334,20 @@ struct llama_arg {
return examples.find(ex) != examples.end();
}

bool get_value_from_env(std::string & output) {
if (env.empty()) return false;
char * value = std::getenv(env.c_str());
if (value) {
output = value;
return true;
}
return false;
}

bool has_value_from_env() {
return std::getenv(env.c_str());
}

std::string to_string(bool markdown);
};

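
The header changes keep llama_arg self-describing: set_env records the variable name and also appends an "(env: NAME)" suffix to the help text, so --help output advertises the fallback automatically. A reduced sketch of just those pieces, using a stripped-down Arg type rather than the full llama_arg:

#include <cstdio>
#include <cstdlib>
#include <string>

// stripped-down stand-in for llama_arg, keeping only the env-related members
struct Arg {
    std::string help;
    std::string env;

    Arg & set_env(std::string name) {
        help = help + "\n(env: " + name + ")"; // advertise the fallback in --help
        env  = std::move(name);
        return *this;
    }

    bool get_value_from_env(std::string & output) const {
        if (env.empty()) return false;
        if (const char * value = std::getenv(env.c_str())) {
            output = value;
            return true;
        }
        return false;
    }
};

int main() {
    Arg arg;
    arg.help = "number of threads to use during generation";
    arg.set_env("LLAMA_ARG_THREADS");

    printf("%s\n", arg.help.c_str()); // help now ends with "(env: LLAMA_ARG_THREADS)"

    std::string value;
    if (arg.get_value_from_env(value)) {
        printf("value from environment: %s\n", value.c_str());
    }
    return 0;
}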
24 changes: 24 additions & 0 deletions tests/test-arg-parser.cpp
@@ -63,5 +63,29 @@ int main(void) {
assert(params.n_predict == 6789);
assert(params.n_batch == 9090);

printf("test-arg-parser: test environment variables (valid + invalid usages)\n\n");

setenv("LLAMA_ARG_THREADS", "blah", true);
argv = {"binary_name"};
assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));

setenv("LLAMA_ARG_MODEL", "blah.gguf", true);
setenv("LLAMA_ARG_THREADS", "1010", true);
argv = {"binary_name"};
assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
assert(params.model == "blah.gguf");
assert(params.cpuparams.n_threads == 1010);


printf("test-arg-parser: test environment variables being overwritten\n\n");

setenv("LLAMA_ARG_MODEL", "blah.gguf", true);
setenv("LLAMA_ARG_THREADS", "1010", true);
argv = {"binary_name", "-m", "overwritten.gguf"};
assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
assert(params.model == "overwritten.gguf");
assert(params.cpuparams.n_threads == 1010);


printf("test-arg-parser: all tests OK\n\n");
}
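
The invalid-usage case above works because the int handler funnels the raw string through std::stoi, which throws for non-numeric input; gpt_params_parse_ex catches the exception and rethrows it as std::invalid_argument naming the offending variable, so parsing fails cleanly instead of crashing. A minimal demonstration of that failure path, independent of the llama.cpp types (the error string merely mimics the parser's format):

#include <cstdio>
#include <stdexcept>
#include <string>

int main() {
    try {
        // same conversion the int handler applies to LLAMA_ARG_THREADS="blah"
        int n = std::stoi("blah");
        printf("parsed %d\n", n);
    } catch (const std::exception & e) {
        // gpt_params_parse_ex wraps this into std::invalid_argument with the env name
        fprintf(stderr,
                "error while handling environment variable \"LLAMA_ARG_THREADS\": %s\n",
                e.what());
        return 1;
    }
    return 0;
}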
