diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 93f99929880f6..8b96a7a6292cd 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2012,6 +2012,10 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n"); printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow); printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast); + printf(" --grp-attn-n N\n"); + printf(" group-attention factor (default: %d)\n", params.grp_attn_n); + printf(" --grp-attn-w N\n"); + printf(" group-attention width (default: %.1f)\n", (double)params.grp_attn_w); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); @@ -2236,6 +2240,24 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } params.yarn_beta_slow = std::stof(argv[i]); } + else if (arg == "--grp-attn-n") + { + if (++i >= argc) + { + invalid_param = true; + break; + } + params.grp_attn_n = std::stoi(argv[i]); + } + else if (arg == "--grp-attn-w") + { + if (++i >= argc) + { + invalid_param = true; + break; + } + params.grp_attn_w = std::stoi(argv[i]); + } else if (arg == "--threads" || arg == "-t") { if (++i >= argc)