@@ -26,6 +26,17 @@ struct server_params {
     int32_t write_timeout = 600;
 };

+// completion token output with probabilities
+struct completion_token_output {
+    struct token_prob {
+        llama_token tok;
+        float prob;
+    };
+
+    std::vector<token_prob> probs;
+    llama_token tok;
+};
+
 static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
     size_t i;
     for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
@@ -86,6 +97,40 @@ static void server_log(const char * level, const char * function, int line,
     fflush(stdout);
 }

+// format incomplete utf-8 multibyte character for output
+static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) {
+    std::string out = token == -1 ? "" : llama_token_to_str(ctx, token);
+    // if first bit is 1, meaning it's a partial character
+    if (out.size() > 0 && (out[0] & 0x80) == 0x80) {
+        std::stringstream ss;
+        ss << std::hex << (out[0] & 0xff);
+        std::string res(ss.str());
+        out = "byte: \\x" + res;
+    }
+    return out;
+}
+
+// convert a vector of completion_token_output to json
+static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> probs) {
+    json out = json::array();
+    for (const auto & prob : probs) {
+        json probs_for_token = json::array();
+        for (const auto & p : prob.probs) {
+            std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
+            probs_for_token.push_back(json {
+                { "tok_str", tok_str },
+                { "prob", p.prob },
+            });
+        }
+        std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
+        out.push_back(json {
+            {"content", tok_str},
+            {"probs", probs_for_token},
+        });
+    }
+    return out;
+}
+
 static bool server_verbose = false;

 #if SERVER_VERBOSE != 1
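For a request with n_probs set, probs_vector_to_json emits one entry per generated token, pairing the token that was actually emitted ("content") with the top candidates and their probabilities ("probs"). A rough, illustrative sketch of the shape (token strings and probability values are made up, not taken from a real run):

    [
        {
            "content": "Hello",
            "probs": [
                { "tok_str": "Hello", "prob": 0.72 },
                { "tok_str": "Hi", "prob": 0.18 }
            ]
        }
    ]

Incomplete multibyte characters are rendered by tokens_to_output_formatted_string as a "byte: \x.." placeholder built from the hex value of the first byte, rather than as raw partial UTF-8.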
@@ -107,6 +152,7 @@ struct llama_server_context {
     bool stream = false;
     bool has_next_token = false;
     std::string generated_text;
+    std::vector<completion_token_output> generated_token_probs;

     size_t num_tokens_predicted = 0;
     size_t n_past = 0;
@@ -142,6 +188,7 @@ struct llama_server_context {
         num_tokens_predicted = 0;
         generated_text = "";
         generated_text.reserve(params.n_ctx);
+        generated_token_probs.clear();
         truncated = false;
         stopped_eos = false;
         stopped_word = false;
@@ -221,8 +268,9 @@ struct llama_server_context {
         llama_set_rng_seed(ctx, params.seed);
     }

-    llama_token nextToken() {
-        llama_token result = -1;
+    completion_token_output nextToken() {
+        completion_token_output result;
+        result.tok = -1;

         if (embd.size() >= (size_t)params.n_ctx) {
             // Reset context
@@ -261,7 +309,8 @@ struct llama_server_context {

         if (params.n_predict == 0) {
             has_next_token = false;
-            return llama_token_eos();
+            result.tok = llama_token_eos();
+            return result;
         }

         // out of user input, sample next token
@@ -278,7 +327,7 @@ struct llama_server_context {
         const float mirostat_tau = params.mirostat_tau;
         const float mirostat_eta = params.mirostat_eta;
         const bool penalize_nl = params.penalize_nl;
-        llama_token id = 0;
+        const int32_t n_probs = params.n_probs;

         {
             auto * logits = llama_get_logits(ctx);
@@ -312,35 +361,42 @@ struct llama_server_context {

             if (temp <= 0) {
                 // Greedy sampling
-                id = llama_sample_token_greedy(ctx, &candidates_p);
+                result.tok = llama_sample_token_greedy(ctx, &candidates_p);
+                if (n_probs > 0) {
+                    llama_sample_softmax(ctx, &candidates_p);
+                }
             } else {
                 if (mirostat == 1) {
                     static float mirostat_mu = 2.0f * mirostat_tau;
                     const int mirostat_m = 100;
                     llama_sample_temperature(ctx, &candidates_p, temp);
-                    id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+                    result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
                 } else if (mirostat == 2) {
                     static float mirostat_mu = 2.0f * mirostat_tau;
                     llama_sample_temperature(ctx, &candidates_p, temp);
-                    id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+                    result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
                 } else {
                     // Temperature sampling
-                    llama_sample_top_k(ctx, &candidates_p, top_k, 1);
-                    llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
-                    llama_sample_typical(ctx, &candidates_p, typical_p, 1);
-                    llama_sample_top_p(ctx, &candidates_p, top_p, 1);
+                    size_t min_keep = std::max(1, n_probs);
+                    llama_sample_top_k(ctx, &candidates_p, top_k, min_keep);
+                    llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep);
+                    llama_sample_typical(ctx, &candidates_p, typical_p, min_keep);
+                    llama_sample_top_p(ctx, &candidates_p, top_p, min_keep);
                     llama_sample_temperature(ctx, &candidates_p, temp);
-                    id = llama_sample_token(ctx, &candidates_p);
+                    result.tok = llama_sample_token(ctx, &candidates_p);
                 }
             }
+
+            for (size_t i = 0; i < std::min(candidates_p.size, (size_t) n_probs); ++i) {
+                result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
+            }
             last_n_tokens.erase(last_n_tokens.begin());
-            last_n_tokens.push_back(id);
+            last_n_tokens.push_back(result.tok);
             num_tokens_predicted++;
         }

         // add it to the context
-        embd.push_back(id);
-        result = id;
+        embd.push_back(result.tok);
         // decrement remaining sampling budget
         --n_remain;

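A note on min_keep: passing std::max(1, n_probs) to the top-k, tail-free, typical and top-p samplers asks each of them to keep at least n_probs candidates alive. For example, with n_probs = 5 and a top_p cutoff that would otherwise leave only two candidates, min_keep forces at least five to survive, so the loop over candidates_p above can still report the requested five probabilities.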
@@ -382,12 +438,16 @@ struct llama_server_context {
         return stop_pos;
     }

-    std::string doCompletion() {
-        const llama_token token = nextToken();
+    completion_token_output doCompletion() {
+        const completion_token_output token_with_probs = nextToken();

-        const std::string token_text = token == -1 ? "" : llama_token_to_str(ctx, token);
+        const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(ctx, token_with_probs.tok);
         generated_text += token_text;

+        if (params.n_probs > 0) {
+            generated_token_probs.push_back(token_with_probs);
+        }
+
         if (multibyte_pending > 0) {
             multibyte_pending -= token_text.size();
         } else if (token_text.size() == 1) {
@@ -416,8 +476,8 @@ struct llama_server_context {
         }

         LOG_VERBOSE("next token", {
-            { "token", token },
-            { "token_text", llama_token_to_str(ctx, token) },
+            { "token", token_with_probs.tok },
+            { "token_text", tokens_to_output_formatted_string(ctx, token_with_probs.tok) },
             { "has_next_token", has_next_token },
             { "n_remain", n_remain },
             { "num_tokens_predicted", num_tokens_predicted },
@@ -427,7 +487,7 @@ struct llama_server_context {
             { "stopping_word", stopping_word },
         });

-        return token_text;
+        return token_with_probs;
     }

     std::vector<float> getEmbedding() {
@@ -669,6 +729,7 @@ static json format_generation_settings(llama_server_context & llama) {
         { "ignore_eos", ignore_eos },
         { "stream", llama.stream },
         { "logit_bias", llama.params.logit_bias },
+        { "n_probs", llama.params.n_probs },
     };
 }

@@ -678,8 +739,9 @@ static json format_embedding_response(llama_server_context & llama) {
     };
 }

-static json format_final_response(llama_server_context & llama, const std::string & content) {
-    return json {
+static json format_final_response(llama_server_context & llama, const std::string & content, const std::vector<completion_token_output> & probs) {
+
+    json res = json {
         { "content", content },
         { "stop", true },
         { "model", llama.params.model_alias },
@@ -692,13 +754,25 @@ static json format_final_response(llama_server_context & llama, const std::strin
         { "stopped_limit", llama.stopped_limit },
         { "stopping_word", llama.stopping_word },
     };
+
+    if (llama.params.n_probs > 0) {
+        res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
+    }
+
+    return res;
 }

-static json format_partial_response(const std::string & content) {
-    return json {
+static json format_partial_response(llama_server_context & llama, const std::string & content, const std::vector<completion_token_output> & probs) {
+    json res = json {
         { "content", content },
         { "stop", false },
     };
+
+    if (llama.params.n_probs > 0) {
+        res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
+    }
+
+    return res;
 }

 static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
@@ -728,6 +802,7 @@ static void parse_options_completion(const json & body, llama_server_context & l
     llama.params.n_keep = body.value("n_keep", default_params.n_keep);
     llama.params.seed = body.value("seed", default_params.seed);
     llama.params.prompt = body.value("prompt", default_params.prompt);
+    llama.params.n_probs = body.value("n_probs", default_params.n_probs);

     llama.params.logit_bias.clear();
     if (body.value("ignore_eos", false)) {
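On the request side this keeps the feature opt-in: a client only has to add an n_probs field to the JSON body of a /completion request, alongside the existing options. A hypothetical request body (values illustrative):

    {
        "prompt": "Hello",
        "n_predict": 8,
        "n_probs": 2
    }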
@@ -830,7 +905,8 @@ int main(int argc, char ** argv) {
             size_t stop_pos = std::string::npos;

             while (llama.has_next_token) {
-                const std::string token_text = llama.doCompletion();
+                const completion_token_output token_with_probs = llama.doCompletion();
+                const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);

                 stop_pos = llama.findStoppingStrings(llama.generated_text,
                     token_text.size(), STOP_FULL);
@@ -844,7 +920,7 @@ int main(int argc, char ** argv) {
                     llama.generated_text.end());
             }

-            const json data = format_final_response(llama, llama.generated_text);
+            const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);

             llama_print_timings(llama.ctx);

@@ -853,9 +929,11 @@ int main(int argc, char ** argv) {
         } else {
             const auto chunked_content_provider = [&](size_t, DataSink & sink) {
                 size_t sent_count = 0;
+                size_t sent_token_probs_index = 0;

                 while (llama.has_next_token) {
-                    const std::string token_text = llama.doCompletion();
+                    const completion_token_output token_with_probs = llama.doCompletion();
+                    const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);
                     if (llama.multibyte_pending > 0) {
                         continue;
                     }
@@ -878,10 +956,22 @@ int main(int argc, char ** argv) {
                     const std::string to_send = llama.generated_text.substr(pos, stop_pos);
                     sent_count += to_send.size();

+                    std::vector<completion_token_output> probs_output = {};
+
+                    if (llama.params.n_probs > 0) {
+                        const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
+                        size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
+                        size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
+                        if (probs_pos < probs_stop_pos) {
+                            probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
+                        }
+                        sent_token_probs_index = probs_stop_pos;
+                    }
+
                     const json data = llama.has_next_token
-                                          ? format_partial_response(to_send)
+                                          ? format_partial_response(llama, to_send, probs_output)
                                           // Generation is done, send extra information.
-                                          : format_final_response(llama, to_send);
+                                          : format_final_response(llama, to_send, llama.generated_token_probs);

                     const std::string str =
                         "data: " +
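In the streaming branch each server-sent chunk therefore carries probabilities only for the tokens it actually delivers (the slice tracked by sent_token_probs_index), while the final chunk, built with format_final_response, includes the full generated_token_probs list. A hypothetical chunk with n_probs set (token strings and values illustrative):

    data: {"content":" world","stop":false,"completion_probabilities":[{"content":" world","probs":[{"tok_str":" world","prob":0.61},{"tok_str":" there","prob":0.14}]}]}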