@@ -181,54 +181,45 @@ int main(int argc, char ** argv) {
         GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
 
         n_past    += ids.size() - 1;
-        n_drafted += batch_tgt.n_tokens - 1;
+        n_drafted += draft.size(); // note: we ignore the discarded small drafts
         n_accept  += ids.size() - 1;
+        n_predict += ids.size();
 
         // process the accepted tokens and update contexts
         //
         // this is the standard token post-processing that we normally do
         // in this case, we do it for a group of accepted tokens at once
         //
-        {
-            llama_token id;
-            std::string token_str;
-
-            for (size_t i = 0; i < ids.size(); ++i) {
-                id = ids[i];
-
-                ++n_predict;
-
-                if (llama_token_is_eog(model_tgt, id)) {
-                    has_eos = true;
-                    break;
-                }
+        for (size_t i = 0; i < ids.size(); ++i) {
+            const llama_token id = ids[i];
 
-                token_str = common_token_to_piece(ctx_tgt, id);
-
-                if (params.use_color && i + 1 < ids.size()) {
-                    LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
-                } else {
-                    LOG("%s", token_str.c_str());
-                }
-            }
+            prompt_tgt.push_back(id_last);
+            id_last = id;
 
-            if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            if (llama_token_is_eog(model_tgt, id)) {
+                has_eos = true;
                 break;
             }
 
-            LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str());
+            const std::string token_str = common_token_to_piece(ctx_tgt, id);
 
-            {
-                LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
-
-                llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+            if (params.use_color && i + 1 < ids.size()) {
+                LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
+            } else {
+                LOG("%s", token_str.c_str());
             }
+        }
 
-            prompt_tgt.push_back(id_last);
-            prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
+        LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
 
-            // remember the last accepted token for the next iteration
-            id_last = id;
+        {
+            LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
+
+            llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+        }
+
+        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            break;
         }
     }
 
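For readers following along, below is a minimal, self-contained sketch of the bookkeeping that the new batched post-processing loop performs. Plain `int` tokens and a hypothetical `is_eog()` stand in for `llama_token` and `llama_token_is_eog()`, and the real llama.cpp calls (`common_token_to_piece`, `llama_kv_cache_seq_rm`, logging) are omitted, so this only illustrates the `prompt_tgt`/`id_last` update pattern, not the actual API.

```cpp
// Sketch of the batched accepted-token bookkeeping (assumptions: plain int
// tokens instead of llama_token, hypothetical is_eog() instead of
// llama_token_is_eog(); context/KV-cache handling is omitted).
#include <cstdio>
#include <vector>

using token = int;

static bool is_eog(token id) { return id == -1; } // hypothetical end-of-generation check

int main() {
    std::vector<token> prompt_tgt = {1, 2, 3}; // tokens already fed to the target model
    token id_last = 4; // last accepted token; deliberately not yet in prompt_tgt

    // tokens accepted this iteration: the validated draft tokens plus the
    // one token sampled by the target model at the end
    const std::vector<token> ids = {5, 6, -1};

    int  n_predict = 0;
    bool has_eos   = false;

    n_predict += (int) ids.size();

    for (size_t i = 0; i < ids.size(); ++i) {
        // the previous "last" token is now permanently part of the prompt;
        // ids[i] becomes the new last accepted token for the next iteration
        prompt_tgt.push_back(id_last);
        id_last = ids[i];

        if (is_eog(id_last)) {
            has_eos = true;
            break;
        }

        printf("accepted token %d\n", id_last);
    }

    printf("n_predict = %d, has_eos = %s, prompt size = %zu\n",
           n_predict, has_eos ? "true" : "false", prompt_tgt.size());

    return 0;
}
```

The invariant, visible in the diff as well, is that `id_last` always lags `prompt_tgt` by one token: it is the token the next drafting step is conditioned on, which is why the loop pushes the previous value before overwriting it.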