@@ -325,6 +325,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     double nll = 0.0;
     double nll2 = 0.0;
 
+    const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+
+    std::vector<float> logits;
+    if (num_batches > 1) {
+        logits.reserve((size_t)n_ctx * n_vocab);
+    }
+
     fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
 
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
@@ -333,10 +340,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         const int start = i * n_ctx;
         const int end   = start + n_ctx;
 
-        const int num_batches = (n_ctx + n_batch - 1) / n_batch;
-
-        std::vector<float> logits;
-
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
@@ -362,8 +365,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             // restore the original token in case it was set to BOS
             tokens[batch_start] = token_org;
 
-            const auto * batch_logits = llama_get_logits(ctx);
-            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+            if (num_batches > 1) {
+                const auto * batch_logits = llama_get_logits(ctx);
+                logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+            }
         }
 
         const auto t_end = std::chrono::high_resolution_clock::now();
@@ -392,7 +397,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         // last 256 tokens. Then, we split the input up into context window size chunks to
         // process the entire prompt.
         const int first = n_ctx/2;
-        process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+        const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
+        process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
                 workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
         count += n_ctx - first - 1;
 
@@ -406,6 +412,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
         }
         fflush(stdout);
+
+        logits.clear();
     }
     printf("\n");
 
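The net effect of the change: the logits buffer is hoisted out of the per-chunk loop (so its capacity is reused across chunks, with logits.clear() resetting it each iteration), and it is only filled at all when a chunk spans more than one decode batch. For a single batch, llama_get_logits(ctx) already points at the logits of the entire chunk, so the per-chunk copy is skipped. Since num_batches = ceil(n_ctx / n_batch), any configuration with n_ctx <= n_batch takes the copy-free path for every chunk.

Below is a minimal standalone sketch of the same pattern. get_batch_logits() is a hypothetical stub standing in for llama_get_logits(), and the logit values are fake; it illustrates only the conditional-copy logic, not the library API.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for llama_get_logits(): returns a pointer to the
// logits produced by the most recent decode (here: a fake, reused buffer).
static const float * get_batch_logits(int batch_size, int n_vocab) {
    static std::vector<float> buf;
    buf.assign((size_t)batch_size * n_vocab, 0.5f); // fake logit values
    return buf.data();
}

int main() {
    const int n_ctx = 8, n_batch = 8, n_vocab = 4;
    const int num_batches = (n_ctx + n_batch - 1) / n_batch;

    // The accumulation buffer is only needed when a chunk spans several batches.
    std::vector<float> logits;
    if (num_batches > 1) {
        logits.reserve((size_t)n_ctx * n_vocab);
    }

    const float * batch_logits = nullptr;
    for (int j = 0; j < num_batches; ++j) {
        const int batch_size = std::min(n_batch, n_ctx - j * n_batch);
        batch_logits = get_batch_logits(batch_size, n_vocab);
        if (num_batches > 1) {
            // Multi-batch chunk: stitch the per-batch logits together.
            logits.insert(logits.end(), batch_logits, batch_logits + (size_t)batch_size * n_vocab);
        }
    }

    // Single-batch chunk: the last decode already holds the whole chunk's
    // logits, so read them in place instead of copying.
    const float * all_logits = num_batches > 1 ? logits.data() : batch_logits;
    printf("num_batches = %d, first logit = %.2f\n", num_batches, all_logits[0]);

    logits.clear(); // keeps capacity for the next chunk in the real loop
    return 0;
}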