@@ -176,8 +176,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
176176 }
177177 }
178178
179- const ggml_type wtype2 = GGML_TYPE_F32;
180-
181179 auto & ctx = model.ctx ;
182180
183181 size_t ctx_size = 0 ;
@@ -237,7 +235,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
237235
238236 const int n_embd = hparams.n_embd ;
239237 const int n_layer = hparams.n_layer ;
240- const int n_ctx = hparams.n_ctx ;
241238 const int n_vocab = hparams.n_vocab ;
242239
243240 model.layers .resize (n_layer);
@@ -539,9 +536,7 @@ bool llama_eval(
539536 const int n_vocab = hparams.n_vocab ;
540537 const int n_rot = hparams.n_embd /hparams.n_head ;
541538
542- const int d_key = n_embd/n_head;
543-
544- // TODO: check if this size scales with n_ctx linearly and remove constant. somehow I feel it wasn't the case
539+ // TODO: check if this size scales with n_ctx linearly and remove constant. somehow I feel it wasn't the case
545540 // static size_t buf_size = hparams.n_ctx*1024*1024;
546541 static size_t buf_size = 512u *1024 *1024 ;
547542 static void * buf = malloc (buf_size);
@@ -792,7 +787,7 @@ int main(int argc, char ** argv) {
792787 if (gpt_params_parse (argc, argv, params) == false ) {
793788 return 1 ;
794789 }
795-
790+
796791 if (params.n_ctx > 2048 ) {
797792 fprintf (stderr, " %s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
798793 " expect poor results\n " , __func__, params.n_ctx );
@@ -820,7 +815,7 @@ int main(int argc, char ** argv) {
820815 // load the model
821816 {
822817 const int64_t t_start_us = ggml_time_us ();
823- if (!llama_model_load (params.model , model, vocab, params.n_ctx )) {
818+ if (!llama_model_load (params.model , model, vocab, params.n_ctx )) {
824819 fprintf (stderr, " %s: failed to load model from '%s'\n " , __func__, params.model .c_str ());
825820 return 1 ;
826821 }
@@ -849,9 +844,25 @@ int main(int argc, char ** argv) {
849844
850845 params.n_predict = std::min (params.n_predict , model.hparams .n_ctx - (int ) embd_inp.size ());
851846
847+ // prefix & suffix for instruct mode
848+ const std::vector<gpt_vocab::id> inp_pfx = ::llama_tokenize (vocab, " \n\n ### Instruction:\n\n " , true );
849+ const std::vector<gpt_vocab::id> inp_sfx = ::llama_tokenize (vocab, " \n\n ### Response:\n\n " , false );
850+
851+ // in instruct mode, we inject a prefix and a suffix to each input by the user
852+ if (params.instruct ) {
853+ fprintf (stderr, " == Instruction mode enabled ==\n " );
854+ params.interactive = true ;
855+ params.antiprompt = " ### Instruction:\n\n " ;
856+ }
857+
852858 // tokenize the reverse prompt
853859 std::vector<gpt_vocab::id> antiprompt_inp = ::llama_tokenize (vocab, params.antiprompt , false );
854860
861+ // enable interactive mode if reverse prompt is specified
862+ if (!antiprompt_inp.empty ()) {
863+ params.interactive = true ;
864+ }
865+
855866 fprintf (stderr, " \n " );
856867 fprintf (stderr, " %s: prompt: '%s'\n " , __func__, params.prompt .c_str ());
857868 fprintf (stderr, " %s: number of tokens in prompt = %zu\n " , __func__, embd_inp.size ());
@@ -872,7 +883,7 @@ int main(int argc, char ** argv) {
872883
873884 fprintf (stderr, " %s: interactive mode on.\n " , __func__);
874885
875- if (antiprompt_inp.size ()) {
886+ if (antiprompt_inp.size ()) {
876887 fprintf (stderr, " %s: reverse prompt: '%s'\n " , __func__, params.antiprompt .c_str ());
877888 fprintf (stderr, " %s: number of tokens in reverse prompt = %zu\n " , __func__, antiprompt_inp.size ());
878889 for (int i = 0 ; i < (int ) antiprompt_inp.size (); i++) {
@@ -894,31 +905,27 @@ int main(int argc, char ** argv) {
894905 std::vector<gpt_vocab::id> last_n_tokens (last_n_size);
895906 std::fill (last_n_tokens.begin (), last_n_tokens.end (), 0 );
896907
897-
898908 if (params.interactive ) {
899909 fprintf (stderr, " == Running in interactive mode. ==\n "
900910#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
901911 " - Press Ctrl+C to interject at any time.\n "
902912#endif
903913 " - Press Return to return control to LLaMa.\n "
904- " - If you want to submit another line, end your input in '\\ '.\n " );
914+ " - If you want to submit another line, end your input in '\\ '.\n\n " );
915+ is_interacting = true ;
905916 }
906917
907- int remaining_tokens = params.n_predict ;
908918 int input_consumed = 0 ;
909919 bool input_noecho = false ;
910920
911- // prompt user immediately after the starting prompt has been loaded
912- if (params.interactive_start ) {
913- is_interacting = true ;
914- }
921+ int remaining_tokens = params.n_predict ;
915922
916923 // set the color for the prompt which will be output initially
917924 if (params.use_color ) {
918925 printf (ANSI_COLOR_YELLOW);
919926 }
920927
921- while (remaining_tokens > 0 ) {
928+ while (remaining_tokens > 0 || params. interactive ) {
922929 // predict
923930 if (embd.size () > 0 ) {
924931 const int64_t t_start_us = ggml_time_us ();
@@ -971,13 +978,13 @@ int main(int argc, char ** argv) {
971978 last_n_tokens.erase (last_n_tokens.begin ());
972979 last_n_tokens.push_back (embd_inp[input_consumed]);
973980 ++input_consumed;
974- if (embd.size () > params.n_batch ) {
981+ if (( int ) embd.size () > params.n_batch ) {
975982 break ;
976983 }
977984 }
978985
979986 // reset color to default if we there is no pending user input
980- if (!input_noecho && params.use_color && embd_inp.size () == input_consumed) {
987+ if (!input_noecho && params.use_color && ( int ) embd_inp.size () == input_consumed) {
981988 printf (ANSI_COLOR_RESET);
982989 }
983990 }
@@ -999,19 +1006,26 @@ int main(int argc, char ** argv) {
9991006 is_interacting = true ;
10001007 }
10011008 if (is_interacting) {
1009+ if (params.instruct ) {
1010+ input_consumed = embd_inp.size ();
1011+ embd_inp.insert (embd_inp.end (), inp_pfx.begin (), inp_pfx.end ());
1012+
1013+ printf (" \n > " );
1014+ }
1015+
10021016 // currently being interactive
1003- bool another_line= true ;
1017+ bool another_line = true ;
10041018 while (another_line) {
10051019 fflush (stdout);
10061020 char buf[256 ] = {0 };
10071021 int n_read;
1008- if (params.use_color ) printf (ANSI_BOLD ANSI_COLOR_GREEN);
1022+ if (params.use_color ) printf (ANSI_BOLD ANSI_COLOR_GREEN);
10091023 if (scanf (" %255[^\n ]%n%*c" , buf, &n_read) <= 0 ) {
10101024 // presumably empty line, consume the newline
10111025 std::ignore = scanf (" %*c" );
10121026 n_read=0 ;
10131027 }
1014- if (params.use_color ) printf (ANSI_COLOR_RESET);
1028+ if (params.use_color ) printf (ANSI_COLOR_RESET);
10151029
10161030 if (n_read > 0 && buf[n_read-1 ]==' \\ ' ) {
10171031 another_line = true ;
@@ -1026,6 +1040,10 @@ int main(int argc, char ** argv) {
10261040 std::vector<gpt_vocab::id> line_inp = ::llama_tokenize (vocab, buf, false );
10271041 embd_inp.insert (embd_inp.end (), line_inp.begin (), line_inp.end ());
10281042
1043+ if (params.instruct ) {
1044+ embd_inp.insert (embd_inp.end (), inp_sfx.begin (), inp_sfx.end ());
1045+ }
1046+
10291047 remaining_tokens -= line_inp.size ();
10301048
10311049 input_noecho = true ; // do not echo this again
@@ -1037,8 +1055,12 @@ int main(int argc, char ** argv) {
10371055
10381056 // end of text token
10391057 if (embd.back () == 2 ) {
1040- fprintf (stderr, " [end of text]\n " );
1041- break ;
1058+ if (params.interactive ) {
1059+ is_interacting = true ;
1060+ } else {
1061+ fprintf (stderr, " [end of text]\n " );
1062+ break ;
1063+ }
10421064 }
10431065 }
10441066
0 commit comments