@@ -527,7 +527,8 @@ bool llama_eval(
         const int n_past,
         const std::vector<gpt_vocab::id> & embd_inp,
               std::vector<float> & embd_w,
-              size_t & mem_per_token) {
+              size_t & mem_per_token,
+              bool return_all_logits = false) {
     const int N = embd_inp.size();

     const auto & hparams = model.hparams;
@@ -733,9 +734,14 @@ bool llama_eval(
     // embd_w.resize(n_vocab*N);
     // memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);

-    // return result for just the last token
-    embd_w.resize(n_vocab);
-    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+    if (return_all_logits) {
+        embd_w.resize(n_vocab * N);
+        memcpy(embd_w.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+    } else {
+        // return result for just the last token
+        embd_w.resize(n_vocab);
+        memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+    }

     if (mem_per_token == 0) {
         mem_per_token = ggml_used_mem(ctx0)/N;
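
Not part of the commit: a minimal sketch of how a caller could index into the flat logits buffer that llama_eval fills when return_all_logits is true, assuming the row-major [N, n_vocab] layout implied by the memcpy above. The helper logits_for_pos and the toy sizes are hypothetical, for illustration only.

#include <cstdio>
#include <vector>

// Hypothetical helper: pointer to the logits row for token position `pos`,
// assuming embd_w holds N * n_vocab floats, row-major, one row per input token.
static const float * logits_for_pos(const std::vector<float> & embd_w, int n_vocab, int pos) {
    return embd_w.data() + (size_t) pos * n_vocab;
}

int main() {
    const int N = 4, n_vocab = 8;                 // toy sizes for illustration
    std::vector<float> embd_w(N * n_vocab, 0.0f); // would be filled by llama_eval(..., true)

    for (int pos = 0; pos < N; ++pos) {
        const float * row = logits_for_pos(embd_w, n_vocab, pos);
        printf("pos %d: first logit = %f\n", pos, row[0]);
    }
    return 0;
}
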
@@ -769,22 +775,42 @@ void perplexity(const gpt_vocab &vocab, const llama_model &model, const gpt_para
     // Output: `perplexity: 13.5106 [114/114]`
     std::vector<gpt_vocab::id> tokens = ::llama_tokenize(vocab, params.prompt, true);

+    int count = 0;
     double nll = 0.0;
     int seq_count = tokens.size() / params.n_ctx;
     for (int i = 0; i < seq_count; ++i) {
         int start = i * params.n_ctx;
         int end = start + params.n_ctx - 1;
         std::vector<gpt_vocab::id> embd(tokens.begin() + start, tokens.begin() + end);
         std::vector<float> logits;
-        if (!llama_eval(model, params.n_threads, 0, embd, logits, mem_per_token)) {
+        if (!llama_eval(model, params.n_threads, 0, embd, logits, mem_per_token, true)) {
             fprintf(stderr, "Failed to predict\n");
             return;
         }
-        // Calculate probability of next token, given the previous ones.
-        double prob = softmax(logits)[tokens[end]];
-        nll += -std::log(prob);
+        // We get the logits for all the tokens in the context window (params.n_ctx)
+        // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
+        // calculate the perplexity over the last half of the window (so the model always has
+        // some context to predict the token).
+        //
+        // We rely on the fact that attention in the forward pass only looks at previous
+        // tokens here, so the logits returned for each token are an accurate representation
+        // of what the model would have predicted at that point.
+        //
+        // For example, with a context window of 512, we compute perplexity for each of the
+        // last 256 tokens. Then, we split the input up into context-window-sized chunks to
+        // process the entire prompt.
+        for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
+            // Calculate probability of the next token, given the previous ones.
+            int n_vocab = model.hparams.n_vocab;
+            std::vector<float> tok_logits(
+                logits.begin() + j * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+            double prob = softmax(tok_logits)[tokens[start + j + 1]];
+            nll += -std::log(prob);
+            ++count;
+        }
         // perplexity is e^(average negative log-likelihood)
-        printf("perplexity: %.4lf [%d/%d]\r", std::exp(nll / (i + 1)), i + 1, seq_count);
+        printf("perplexity: %.4lf [%d/%d]\r", std::exp(nll / count), i + 1, seq_count);
         fflush(stdout);
     }
     printf("\n");
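
Not part of the commit: a self-contained sketch of the arithmetic the loop above performs, perplexity = exp((1/count) * sum of -log p(token | previous tokens)), run on toy logits. softmax_sketch is a stand-in for the softmax helper already defined elsewhere in the file; the toy logits and token ids are made up for illustration.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Stand-in softmax, subtracting the max logit for numerical stability.
static std::vector<float> softmax_sketch(const std::vector<float> & logits) {
    std::vector<float> probs(logits.size());
    float max_l = *std::max_element(logits.begin(), logits.end());
    double sum = 0.0;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_l);
        sum += probs[i];
    }
    for (auto & p : probs) {
        p /= sum;
    }
    return probs;
}

int main() {
    // Toy data: logits for 3 positions over a vocabulary of 4, plus the token
    // actually observed at each following position.
    std::vector<std::vector<float>> tok_logits = {
        {2.0f, 0.1f, 0.1f, 0.1f},
        {0.1f, 1.5f, 0.1f, 0.1f},
        {0.1f, 0.1f, 0.1f, 3.0f},
    };
    std::vector<int> next_token = {0, 1, 3};

    double nll = 0.0;
    int count = 0;
    for (size_t j = 0; j < tok_logits.size(); ++j) {
        double prob = softmax_sketch(tok_logits[j])[next_token[j]];
        nll += -std::log(prob);   // accumulate negative log-likelihood
        ++count;
    }
    printf("perplexity: %.4lf\n", std::exp(nll / count));
    return 0;
}
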