@@ -523,7 +523,7 @@ int main(int argc, char ** argv) {
523523
524524 const llama_token id = llama_sampling_sample (ctx_sampling, ctx, ctx_guidance);
525525
526- llama_sampling_accept (ctx_sampling, ctx, id);
526+ llama_sampling_accept (ctx_sampling, ctx, id, true );
527527
528528 LOG (" last: %s\n " , LOG_TOKENS_TOSTR_PRETTY (ctx, ctx_sampling->prev ).c_str ());
529529
@@ -541,8 +541,11 @@ int main(int argc, char ** argv) {
541541 LOG (" embd_inp.size(): %d, n_consumed: %d\n " , (int ) embd_inp.size (), n_consumed);
542542 while ((int ) embd_inp.size () > n_consumed) {
543543 embd.push_back (embd_inp[n_consumed]);
544- ctx_sampling->prev .erase (ctx_sampling->prev .begin ());
545- ctx_sampling->prev .push_back (embd_inp[n_consumed]);
544+
545+ // push the prompt in the sampling context in order to apply repetition penalties later
546+ // for the prompt, we don't apply grammar rules
547+ llama_sampling_accept (ctx_sampling, ctx, embd_inp[n_consumed], false );
548+
546549 ++n_consumed;
547550 if ((int ) embd.size () >= params.n_batch ) {
548551 break ;
@@ -574,7 +577,7 @@ int main(int argc, char ** argv) {
574577 if ((int ) embd_inp.size () <= n_consumed) {
575578
576579 // deal with eot token in infill mode
577- if ((ctx_sampling-> prev . back ( ) == llama_token_eot (ctx) || is_interacting) && params.interactive ){
580+ if ((llama_sampling_last (ctx_sampling ) == llama_token_eot (ctx) || is_interacting) && params.interactive ){
578581 if (is_interacting && !params.interactive_first ) {
579582 // print an eot token
580583 printf (" %s" , llama_token_to_piece (ctx, llama_token_eot (ctx)).c_str ());
@@ -591,7 +594,7 @@ int main(int argc, char ** argv) {
591594 buffer += line;
592595 } while (another_line);
593596 // check if we got an empty line, if so we use the old input
594- if (!buffer.empty () && !(buffer.length () == 1 && buffer[0 ] == ' \n ' )) {
597+ if (!buffer.empty () && !(buffer.length () == 1 && buffer[0 ] == ' \n ' )) {
595598 params.input_prefix = buffer;
596599 }
597600 buffer.clear ();
@@ -601,7 +604,7 @@ int main(int argc, char ** argv) {
601604 buffer += line;
602605 } while (another_line);
603606 // check if we got an empty line
604- if (!buffer.empty () && !(buffer.length () == 1 && buffer[0 ] == ' \n ' )) {
607+ if (!buffer.empty () && !(buffer.length () == 1 && buffer[0 ] == ' \n ' )) {
605608 params.input_suffix = buffer;
606609 }
607610 buffer.clear ();
@@ -614,7 +617,7 @@ int main(int argc, char ** argv) {
614617 process_escapes (params.input_suffix );
615618 }
616619 suff_rm_leading_spc = params.escape ;
617- if (suff_rm_leading_spc && params.input_suffix .find_first_of (" " ) == 0 && params.input_suffix .size () > 1 ) {
620+ if (suff_rm_leading_spc && params.input_suffix .find_first_of (' ' ) == 0 && params.input_suffix .size () > 1 ) {
618621 params.input_suffix .erase (0 , 1 );
619622 suff_rm_leading_spc = false ;
620623 }
@@ -641,7 +644,7 @@ int main(int argc, char ** argv) {
641644 is_interacting = false ;
642645 }
643646 // deal with end of text token in interactive mode
644- else if (ctx_sampling-> prev . back ( ) == llama_token_eos (ctx)) {
647+ else if (llama_sampling_last (ctx_sampling ) == llama_token_eos (ctx)) {
645648 LOG (" found EOS token\n " );
646649
647650 if (params.interactive ) {
0 commit comments