@@ -35,12 +35,22 @@ int main(int argc, char ** argv) {
     auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
 
     // init
-    auto ctx = llama_init_from_file(params.model.c_str(), lparams);
+    auto model = llama_load_model_from_file(params.model.c_str(), lparams);
+    if (model == nullptr) {
+        return 1;
+    }
+    auto ctx = llama_new_context_with_model(model, lparams);
+    if (ctx == nullptr) {
+        llama_free_model(model);
+        return 1;
+    }
     auto tokens = std::vector<llama_token>(params.n_ctx);
     auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true);
 
     if (n_prompt_tokens < 1) {
         fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
+        llama_free(ctx);
+        llama_free_model(model);
         return 1;
     }
 
@@ -84,30 +94,36 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str);
         if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+            llama_free(ctx);
+            llama_free_model(model);
             return 1;
         }
         n_past += 1;
     }
 
     printf("\n\n");
 
-    // free old model
+    // free old context
     llama_free(ctx);
 
-    // load new model
-    auto ctx2 = llama_init_from_file(params.model.c_str(), lparams);
+    // make new context
+    auto ctx2 = llama_new_context_with_model(model, lparams);
 
     // Load state (rng, logits, embedding and kv_cache) from file
     {
         FILE *fp_read = fopen("dump_state.bin", "rb");
         if (state_size != llama_get_state_size(ctx2)) {
             fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
+            llama_free(ctx2);
+            llama_free_model(model);
             return 1;
         }
 
         const size_t ret = fread(state_mem, 1, state_size, fp_read);
         if (ret != state_size) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
+            llama_free(ctx2);
+            llama_free_model(model);
             return 1;
         }
 
@@ -138,12 +154,17 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str);
         if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+            llama_free(ctx2);
+            llama_free_model(model);
             return 1;
         }
         n_past += 1;
     }
 
     printf("\n\n");
 
+    llama_free(ctx2);
+    llama_free_model(model);
+
     return 0;
 }
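
For reference, here is a minimal standalone sketch of the lifecycle this diff introduces: load the weights once with `llama_load_model_from_file`, create one or more contexts over them with `llama_new_context_with_model` (the example's `ctx` and `ctx2` share the same `model`), and on teardown free each context before freeing the model. This is only a sketch, not the example itself: the model path is hypothetical, backend/NUMA initialization is omitted, and it assumes `llama.h` as of this commit, including `llama_context_default_params()` for the `lparams` struct the example already uses.

```cpp
#include "llama.h"

#include <cstdio>

int main() {
    // default context parameters (the example builds its `lparams` the same way)
    auto lparams = llama_context_default_params();

    // load the weights once (hypothetical path)
    auto model = llama_load_model_from_file("models/7B/ggml-model.bin", lparams);
    if (model == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // create a context over the shared weights; additional contexts can reuse `model`
    auto ctx = llama_new_context_with_model(model, lparams);
    if (ctx == nullptr) {
        llama_free_model(model);
        return 1;
    }

    // ... tokenize, llama_eval, save/restore state, etc. ...

    // teardown: free contexts first, then the model
    llama_free(ctx);
    llama_free_model(model);
    return 0;
}
```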