@@ -305,6 +305,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_keep = std::stoi(argv[i]);
+        } else if (arg == "--draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_draft = std::stoi(argv[i]);
         } else if (arg == "--chunks") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -317,6 +323,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.model = argv[i];
+        } else if (arg == "-md" || arg == "--model-draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.model_draft = argv[i];
         } else if (arg == "-a" || arg == "--alias") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -638,6 +650,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --hellaswag           compute HellaSwag score over random tasks from datafile supplied with -f\n");
     fprintf(stdout, "  --hellaswag-tasks N   number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
     fprintf(stdout, "  --keep N              number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
+    fprintf(stdout, "  --draft N             number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
     fprintf(stdout, "  --chunks N            max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
     if (llama_mlock_supported()) {
         fprintf(stdout, "  --mlock               force system to keep model in RAM rather than swapping or compressing\n");
@@ -669,6 +682,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
     fprintf(stdout, "  -m FNAME, --model FNAME\n");
     fprintf(stdout, "                        model path (default: %s)\n", params.model.c_str());
+    fprintf(stdout, "  -md FNAME, --model-draft FNAME\n");
+    fprintf(stdout, "                        draft model for speculative decoding (default: %s)\n", params.model.c_str());
     fprintf(stdout, "  -ld LOGDIR, --logdir LOGDIR\n");
     fprintf(stdout, "                        path under which to save YAML logs (no logging if unset)\n");
     fprintf(stdout, "\n");
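Note (illustration only, not part of the diff): --draft fills params.n_draft, the number of tokens the draft model proposes per step, and -md/--model-draft fills params.model_draft, the path of the smaller draft model. A hypothetical invocation would pass both, e.g. "-m models/7B/ggml-model.bin -md models/1B/ggml-model.bin --draft 16", where the model paths are placeholders rather than values taken from this change.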
@@ -832,6 +847,130 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_to
     return result;
 }
 
+//
+// Sampling utils
+//
+
+llama_token llama_sample_token(
+                  struct llama_context * ctx,
+                  struct llama_context * ctx_guidance,
+                  struct llama_grammar * grammar,
+               const struct gpt_params & params,
+        const std::vector<llama_token> & last_tokens,
+         std::vector<llama_token_data> & candidates,
+                                   int   idx) {
+    const int n_ctx   = llama_n_ctx(ctx);
+    const int n_vocab = llama_n_vocab(ctx);
+
+    const float   temp            = params.temp;
+    const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k;
+    const float   top_p           = params.top_p;
+    const float   tfs_z           = params.tfs_z;
+    const float   typical_p       = params.typical_p;
+    const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
+    const float   repeat_penalty  = params.repeat_penalty;
+    const float   alpha_presence  = params.presence_penalty;
+    const float   alpha_frequency = params.frequency_penalty;
+    const int     mirostat        = params.mirostat;
+    const float   mirostat_tau    = params.mirostat_tau;
+    const float   mirostat_eta    = params.mirostat_eta;
+    const bool    penalize_nl     = params.penalize_nl;
+
+    llama_token id = 0;
+
+    float * logits = llama_get_logits(ctx) + idx * n_vocab;
+
+    // Apply params.logit_bias map
+    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+        logits[it->first] += it->second;
+    }
+
+    candidates.clear();
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+
+    if (ctx_guidance) {
+        llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
+    }
+
+    // apply penalties
+    if (!last_tokens.empty()) {
+        const float nl_logit = logits[llama_token_nl(ctx)];
+        const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
+
+        llama_sample_repetition_penalty(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, repeat_penalty);
+        llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, alpha_frequency, alpha_presence);
+
+        if (!penalize_nl) {
+            for (size_t idx = 0; idx < cur_p.size; idx++) {
+                if (cur_p.data[idx].id == llama_token_nl(ctx)) {
+                    cur_p.data[idx].logit = nl_logit;
+                    break;
+                }
+            }
+        }
+    }
+
+    if (grammar != NULL) {
+        llama_sample_grammar(ctx, &cur_p, grammar);
+    }
+
+    if (temp <= 0) {
+        // Greedy sampling
+        id = llama_sample_token_greedy(ctx, &cur_p);
+    } else {
+        if (mirostat == 1) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            const int mirostat_m = 100;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+        } else if (mirostat == 2) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+        } else {
+            // Temperature sampling
+            llama_sample_top_k      (ctx, &cur_p, top_k, 1);
+            llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1);
+            llama_sample_typical    (ctx, &cur_p, typical_p, 1);
+            llama_sample_top_p      (ctx, &cur_p, top_p, 1);
+            llama_sample_temperature(ctx, &cur_p, temp);
+
+            {
+                const int n_top = 10;
+                LOG("top %d candidates:\n", n_top);
+
+                for (int i = 0; i < n_top; i++) {
+                    const llama_token id = cur_p.data[i].id;
+                    LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
+                }
+            }
+
+            id = llama_sample_token(ctx, &cur_p);
+
+            LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
+        }
+    }
+    // printf("`%d`", candidates_p.size);
+
+    if (grammar != NULL) {
+        llama_grammar_accept_token(ctx, grammar, id);
+    }
+
+    return id;
+}
+
+//
+// YAML utils
+//
+
 // returns true if successful, false otherwise
 bool create_directory_with_parents(const std::string & path) {
 #ifdef _WIN32
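For orientation, here is a hypothetical caller-side sketch (not part of this diff) showing how the new llama_sample_token() helper is meant to be invoked once logits are available. It assumes a llama_context * ctx and a gpt_params params already set up, e.g. via llama_init_from_gpt_params() and gpt_params_parse(); the buffer names are illustrative.

    // reusable buffers: last_tokens holds recent history for the repetition
    // penalties, candidates is the n_vocab-sized scratch array reused per call
    std::vector<llama_token>      last_tokens(llama_n_ctx(ctx), 0);
    std::vector<llama_token_data> candidates;
    candidates.reserve(llama_n_vocab(ctx));

    // sample from logit row 0 of the most recent evaluation,
    // with no guidance context and no grammar
    const llama_token id = llama_sample_token(ctx, NULL, NULL, params, last_tokens, candidates, 0);

    // update the history window before the next call
    last_tokens.erase(last_tokens.begin());
    last_tokens.push_back(id);

Passing candidates in by reference lets the caller allocate the per-token scratch array once and reuse it across sampling calls.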
@@ -1070,6 +1209,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta);
     fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
     fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str());
+    fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
     fprintf(stream, "mtest: %s # default: false\n", params.mem_test ? "true" : "false");
     fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
     fprintf(stream, "n_gpu_layers: %d # default: 0\n", params.n_gpu_layers);