@@ -117,7 +117,8 @@ int main(int argc, char ** argv) {
     llama_token id_last = inp.back();
 
     // all tokens currently in the target context
-    auto prompt_tgt = std::vector<llama_token>(inp.begin(), inp.end() - 1);
+    llama_tokens prompt_tgt(inp.begin(), inp.end() - 1);
+    prompt_tgt.reserve(llama_n_ctx(ctx_tgt));
 
     int n_past = inp.size() - 1;
 
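A note on the hunk above (my addition, not part of the commit): as far as I can tell, `llama_tokens` in llama.cpp's common helpers is simply an alias for `std::vector<llama_token>`, so the container stays the same; the new line additionally reserves capacity up to the target context size (`llama_n_ctx(ctx_tgt)`) so that the push_back calls in the decode loop below do not reallocate. A minimal stand-alone sketch of the same pattern, with a plain `int` token type and a hard-coded context size standing in for the llama.cpp types and calls:

#include <cstdio>
#include <vector>

using token = int; // stand-in for llama_token

int main() {
    const size_t n_ctx = 4096;                   // stand-in for llama_n_ctx(ctx_tgt)
    const std::vector<token> inp = {1, 2, 3, 4}; // the tokenized prompt

    // keep everything except the last prompt token, then reserve up to the
    // context size so later push_back calls do not trigger reallocations
    std::vector<token> prompt_tgt(inp.begin(), inp.end() - 1);
    prompt_tgt.reserve(n_ctx);

    printf("size = %zu, capacity = %zu\n", prompt_tgt.size(), prompt_tgt.capacity());
    return 0;
}
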
@@ -181,54 +182,44 @@ int main(int argc, char ** argv) {
         GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
 
         n_past    += ids.size() - 1;
-        n_drafted += batch_tgt.n_tokens - 1;
+        n_drafted += draft.size(); // note: we ignore the discarded small drafts
         n_accept  += ids.size() - 1;
+        n_predict += ids.size();
 
         // process the accepted tokens and update contexts
         //
         // this is the standard token post-processing that we normally do
         // in this case, we do it for a group of accepted tokens at once
         //
-        {
-            llama_token id;
-            std::string token_str;
-
-            for (size_t i = 0; i < ids.size(); ++i) {
-                id = ids[i];
-
-                ++n_predict;
-
-                if (llama_token_is_eog(model_tgt, id)) {
-                    has_eos = true;
-                    break;
-                }
-
-                token_str = common_token_to_piece(ctx_tgt, id);
+        for (size_t i = 0; i < ids.size(); ++i) {
+            prompt_tgt.push_back(id_last);
 
-                if (params.use_color && i + 1 < ids.size()) {
-                    LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
-                } else {
-                    LOG("%s", token_str.c_str());
-                }
-            }
+            id_last = ids[i];
 
-            if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            if (llama_token_is_eog(model_tgt, id_last)) {
+                has_eos = true;
                 break;
             }
 
-            LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str());
+            const std::string token_str = common_token_to_piece(ctx_tgt, id_last);
 
-            {
-                LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
-
-                llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+            if (params.use_color && i + 1 < ids.size()) {
+                LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
+            } else {
+                LOG("%s", token_str.c_str());
             }
+        }
 
-            prompt_tgt.push_back(id_last);
-            prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
+        LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
+
+        {
+            LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
+
+            llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+        }
 
-            // remember the last accepted token for the next iteration
-            id_last = id;
+        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            break;
         }
     }
 
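To make the restructured accept loop in the second hunk easier to follow, here is a small self-contained sketch (mine, not from the commit) of the data flow it implements: each accepted token first pushes the previous "last" token into the target prompt and then becomes the new "last" token, replacing the old end-of-block bookkeeping that used `prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1)` and `id_last = id`. The llama.cpp calls are replaced with plain standard-library stand-ins.

#include <cstdio>
#include <vector>

int main() {
    std::vector<int> prompt_tgt = {1, 2, 3}; // tokens already in the target context
    int id_last = 4;                         // last sampled token, not yet in prompt_tgt
    const std::vector<int> ids = {5, 6, 7};  // tokens accepted in this iteration

    // per-token post-processing: the previous "last" token joins the prompt,
    // the current accepted token becomes the new "last" token
    for (size_t i = 0; i < ids.size(); ++i) {
        prompt_tgt.push_back(id_last);
        id_last = ids[i];
        printf("emit %d\n", id_last);
    }

    // invariant after the loop: prompt_tgt holds every token before id_last,
    // which is what the old code achieved with the trailing insert/assignment
    printf("id_last = %d, prompt size = %zu\n", id_last, prompt_tgt.size());
    return 0;
}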