@@ -127,7 +127,7 @@ struct cmd_params {
     std::vector<int> n_prompt;
     std::vector<int> n_gen;
     std::vector<int> n_batch;
-    std::vector<bool> f32_kv;
+    std::vector<ggml_type> kv_type;
     std::vector<int> n_threads;
     std::vector<int> n_gpu_layers;
     std::vector<int> main_gpu;
@@ -144,7 +144,7 @@ static const cmd_params cmd_params_defaults = {
     /* n_prompt      */ {512},
     /* n_gen         */ {128},
     /* n_batch       */ {512},
-    /* f32_kv        */ {false},
+    /* kv_type       */ {GGML_TYPE_Q8_0},
     /* n_threads     */ {get_num_physical_cores()},
     /* n_gpu_layers  */ {99},
     /* main_gpu      */ {0},
@@ -165,7 +165,16 @@ static void print_usage(int /* argc */, char ** argv) {
165165    printf ("   -p, --n-prompt <n>                (default: %s)\n " join (cmd_params_defaults.n_prompt , " ," c_str ());
166166    printf ("   -n, --n-gen <n>                   (default: %s)\n " join (cmd_params_defaults.n_gen , " ," c_str ());
167167    printf ("   -b, --batch-size <n>              (default: %s)\n " join (cmd_params_defaults.n_batch , " ," c_str ());
168-     printf ("   --memory-f32 <0|1>                (default: %s)\n " join (cmd_params_defaults.f32_kv , " ," c_str ());
168+ 
169+     std::string kv_type_default;
170+     for  (unsigned  int  i = 0 ; i < cmd_params_defaults.kv_type .size (); ++i) {
171+         if  (i > 0 ) {
172+             kv_type_default += " ," 
173+         }
174+         kv_type_default += ggml_type_name (cmd_params_defaults.kv_type [i]);
175+     }
176+     printf ("   -kvt, --kv_type <q8_0|f16|f32>    (default: %s)\n " c_str ());
177+ 
169178    printf ("   -t, --threads <n>                 (default: %s)\n " join (cmd_params_defaults.n_threads , " ," c_str ());
170179    printf ("   -ngl N, --n-gpu-layers <n>        (default: %s)\n " join (cmd_params_defaults.n_gpu_layers , " ," c_str ());
171180    printf ("   -mg i, --main-gpu <n>             (default: %s)\n " join (cmd_params_defaults.main_gpu , " ," c_str ());
@@ -177,7 +186,6 @@ static void print_usage(int /* argc */, char ** argv) {
177186    printf ("   -v, --verbose                     (default: %s)\n " verbose  ? " 1" " 0" 
178187    printf (" \n " 
179188    printf (" Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n " 
180- 
181189}
182190
183191static  cmd_params parse_cmd_params (int  argc, char  ** argv) {
@@ -228,13 +236,32 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             }
             auto p = split<int>(argv[i], split_delim);
             params.n_batch.insert(params.n_batch.end(), p.begin(), p.end());
-        } else if (arg == "--memory-f32") {
+        } else if (arg == "-kvt" || arg == "--kv-type") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            auto p = split<int>(argv[i], split_delim);
-            params.f32_kv.insert(params.f32_kv.end(), p.begin(), p.end());
+            auto p = split<std::string>(argv[i], split_delim);
+
+            std::vector<ggml_type> kvt;
+            for (const std::string & type_name : p) {
+                if (type_name == "q8_0") {
+                    kvt.push_back(GGML_TYPE_Q8_0);
+                } else if (type_name == "f16") {
+                    kvt.push_back(GGML_TYPE_F16);
+                } else if (type_name == "f32") {
+                    kvt.push_back(GGML_TYPE_F32);
+                } else {
+                    invalid_param = true;
+                    break;
+                }
+            }
+            if (invalid_param) {
+                fprintf(stderr, "error: unknown KV type: %s. Known types: Q8_0, F16, F32.\n", argv[i]);
+                break;
+            }
+
+            params.kv_type.insert(params.kv_type.end(), kvt.begin(), kvt.end());
         } else if (arg == "-t" || arg == "--threads") {
             if (++i >= argc) {
                 invalid_param = true;
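Aside (not part of the patch): the hunk above maps user-supplied type names to ggml_type values inline. A minimal standalone sketch of the same mapping, factored into a helper, is shown below; the helper name kv_type_from_name and the use of GGML_TYPE_COUNT as an "unknown" sentinel are my own illustrative assumptions, not code from this change.

    // Illustrative sketch only: map a type name from the command line to a ggml_type.
    // GGML_TYPE_COUNT is used here as an "unknown" sentinel (an assumption, not part of the patch).
    #include <string>
    #include "ggml.h"

    static ggml_type kv_type_from_name(const std::string & name) {
        if (name == "q8_0") { return GGML_TYPE_Q8_0; }
        if (name == "f16")  { return GGML_TYPE_F16;  }
        if (name == "f32")  { return GGML_TYPE_F32;  }
        return GGML_TYPE_COUNT; // not a supported KV cache type
    }

With such a helper, the parsing loop would reduce to one call per comma-separated token plus a sentinel check, which keeps the accepted names in a single place.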
@@ -332,7 +359,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.n_prompt.empty())     { params.n_prompt = cmd_params_defaults.n_prompt; }
     if (params.n_gen.empty())        { params.n_gen = cmd_params_defaults.n_gen; }
     if (params.n_batch.empty())      { params.n_batch = cmd_params_defaults.n_batch; }
-    if (params.f32_kv.empty())       { params.f32_kv = cmd_params_defaults.f32_kv; }
+    if (params.kv_type.empty())      { params.kv_type = cmd_params_defaults.kv_type; }
     if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; }
     if (params.main_gpu.empty())     { params.main_gpu = cmd_params_defaults.main_gpu; }
     if (params.mul_mat_q.empty())    { params.mul_mat_q = cmd_params_defaults.mul_mat_q; }
@@ -348,7 +375,7 @@ struct cmd_params_instance {
     int n_prompt;
     int n_gen;
     int n_batch;
-    bool f32_kv;
+    ggml_type kv_type;
     int n_threads;
     int n_gpu_layers;
     int main_gpu;
@@ -360,7 +387,7 @@ struct cmd_params_instance {
         llama_context_params lparams = llama_context_default_params();
         lparams.n_ctx = n_prompt + n_gen;
         lparams.n_batch = n_batch;
-        lparams.f16_kv = !f32_kv;
+        lparams.kv_type = kv_type;
         lparams.n_gpu_layers = n_gpu_layers;
         lparams.main_gpu = main_gpu;
         lparams.mul_mat_q = mul_mat_q;
@@ -376,7 +403,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
 
     for (const auto & m : params.model)
     for (const auto & nb : params.n_batch)
-    for (const auto & fk : params.f32_kv)
+    for (const auto & kvt : params.kv_type)
     for (const auto & nl : params.n_gpu_layers)
     for (const auto & mg : params.main_gpu)
     for (const auto & mmq : params.mul_mat_q)
@@ -388,7 +415,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
             /* .n_prompt     = */ n_prompt,
             /* .n_gen        = */ n_gen,
             /* .n_batch      = */ nb,
-            /* .f32_kv       = */ fk,
+            /* .kv_type      = */ kvt,
             /* .n_threads    = */ nt,
             /* .n_gpu_layers = */ nl,
             /* .main_gpu     = */ mg,
@@ -439,7 +466,7 @@ struct test {
     uint64_t model_n_params;
     int n_batch;
     int n_threads;
-    bool f32_kv;
+    ggml_type kv_type;
     int n_gpu_layers;
     int main_gpu;
     bool mul_mat_q;
@@ -459,7 +486,7 @@ struct test {
         model_n_params = llama_model_n_params(lmodel);
         n_batch = inst.n_batch;
         n_threads = inst.n_threads;
-        f32_kv = inst.f32_kv;
+        kv_type = inst.kv_type;
         n_gpu_layers = inst.n_gpu_layers;
         main_gpu = inst.main_gpu;
         mul_mat_q = inst.mul_mat_q;
@@ -523,7 +550,7 @@ struct test {
523550            " cuda" " opencl" " metal" " gpu_blas" " blas" 
524551            " cpu_info" " gpu_info" 
525552            " model_filename" " model_type" " model_size" " model_n_params" 
526-             " n_batch" " n_threads" " f16_kv " 
553+             " n_batch" " n_threads" " kv_type " 
527554            " n_gpu_layers" " main_gpu" " mul_mat_q" " low_vram" " tensor_split" 
528555            " n_prompt" " n_gen" " test_time" 
529556            " avg_ns" " stddev_ns" 
@@ -543,7 +570,7 @@ struct test {
             return INT;
         }
         if (field == "cuda" || field == "opencl" || field == "metal" || field == "gpu_blas" || field == "blas" ||
-            field == "f16_kv" || field == "mul_mat_q" || field == "low_vram") {
+            field == "mul_mat_q" || field == "low_vram") {
             return BOOL;
         }
         if (field == "avg_ts" || field == "stddev_ts") {
@@ -573,7 +600,7 @@ struct test {
             std::to_string(cuda), std::to_string(opencl), std::to_string(metal), std::to_string(gpu_blas), std::to_string(blas),
             cpu_info, gpu_info,
             model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
-            std::to_string(n_batch), std::to_string(n_threads), std::to_string(!f32_kv),
+            std::to_string(n_batch), std::to_string(n_threads), std::string(ggml_type_name(kv_type)),
             std::to_string(n_gpu_layers), std::to_string(main_gpu), std::to_string(mul_mat_q), std::to_string(low_vram), tensor_split_str,
             std::to_string(n_prompt), std::to_string(n_gen), test_time,
             std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -757,8 +784,8 @@ struct markdown_printer : public printer {
         if (params.n_batch.size() > 1 || params.n_batch != cmd_params_defaults.n_batch) {
             fields.push_back("n_batch");
         }
-        if (params.f32_kv.size() > 1 || params.f32_kv != cmd_params_defaults.f32_kv) {
-            fields.push_back("f16_kv");
+        if (params.kv_type.size() > 1 || params.kv_type != cmd_params_defaults.kv_type) {
+            fields.push_back("kv_type");
         }
         if (params.main_gpu.size() > 1 || params.main_gpu != cmd_params_defaults.main_gpu) {
             fields.push_back("main_gpu");
@@ -826,6 +853,9 @@ struct markdown_printer : public printer {
826853            } else  if  (field == " t/s" 
827854                snprintf (buf, sizeof (buf), " %.2f ± %.2f" avg_ts (), t.stdev_ts ());
828855                value = buf;
856+             } else  if  (field == " kv_type" 
857+                 snprintf (buf, sizeof (buf), " %s" ggml_type_name (t.kv_type ));
858+                 value = buf;
829859            } else  if  (vmap.find (field) != vmap.end ()) {
830860                value = vmap.at (field);
831861            } else  {
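Aside (not part of the patch): the new kv_type column relies on ggml_type_name() returning the same short lowercase names that the -kvt/--kv-type flag accepts. A tiny self-contained check of that assumption:

    // Prints the names used for the kv_type column; expected: "q8_0 f16 f32"
    // (the names defined in ggml's type traits table).
    #include <cstdio>
    #include "ggml.h"

    int main() {
        printf("%s %s %s\n",
               ggml_type_name(GGML_TYPE_Q8_0),
               ggml_type_name(GGML_TYPE_F16),
               ggml_type_name(GGML_TYPE_F32));
        return 0;
    }

If those names ever diverged from the strings checked in parse_cmd_params, values printed in the results table could no longer be passed back to -kvt, so keeping the two sides in sync (or deriving the parser's accepted names from ggml_type_name) is the safer design.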