5 files changed, +26 -5 lines changed
Original file line number Diff line number Diff line change @@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
166166
167167                //  note: tracking the other way around is not necessary for now
168168                // seq_cpl[s0][s1] = true;
169+ 
170+                 has_cpl = true ;
169171            }
170172        }
171173    }
@@ -466,9 +468,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
466468    return  ubatch_add (idxs, idxs.size (), false );
467469}
468470
469- llama_ubatch llama_batch_allocr::split_equal (uint32_t  n_ubatch) {
471+ llama_ubatch llama_batch_allocr::split_equal (uint32_t  n_ubatch, bool  sequential) {
472+     if  (sequential && has_cpl) {
473+         LLAMA_LOG_ERROR (" %s: sequential split is not supported when there are coupled sequences in the input batch\n "  , __func__);
474+ 
475+         return  {};
476+     }
477+ 
470478    std::vector<seq_set_t > cur_seq_set;
471479
480+     llama_seq_id last_seq_id = -1 ;
481+ 
472482    //  determine the non-overlapping sequence sets participating in this ubatch
473483    for  (int32_t  i = 0 ; i < batch.n_tokens ; ++i) {
474484        if  (used[i]) {
@@ -485,9 +495,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
485495            }
486496        }
487497
498+         //  accept only consecutive sequence ids: the first seq id of token i must be exactly last_seq_id + 1
499+         if  (sequential) {
500+             add = add && (cur_seq_set.empty () || batch.seq_id [i][0 ] == last_seq_id + 1 );
501+         }
502+ 
488503        if  (add) {
489504            cur_seq_set.push_back (seq_set[i]);
490505
506+             last_seq_id = batch.seq_id [i][0 ];
507+ 
491508            if  (cur_seq_set.size () > n_ubatch) {
492509                break ;
493510            }
Original file line number Diff line number Diff line change @@ -70,7 +70,8 @@ class llama_batch_allocr {
7070    llama_ubatch split_simple (uint32_t  n_ubatch);
7171
7272    //  make ubatches of equal-length sequences sets
73-     llama_ubatch split_equal (uint32_t  n_ubatch);
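73+     //  if sequential == true, sequence sets are added only while their first sequence id is exactly one greater than the previous one; not supported (logs an error, returns an empty ubatch) when the batch has coupled sequences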
74+     llama_ubatch split_equal (uint32_t  n_ubatch, bool  sequential);
7475
7576    //  sequence-set-wise split - each ubatch contains a single sequence-set
7677    llama_ubatch split_seq (uint32_t  n_ubatch);
@@ -113,6 +114,9 @@ class llama_batch_allocr {
113114    using  pos_set_t  = std::set<llama_pos>;
114115    using  seq_cpl_t  = std::vector<bool >;
115116
117+     //  helper flag to quickly determine if there are any coupled sequences in the batch
118+     bool  has_cpl;
119+ 
116120    std::vector<pos_set_t > seq_pos; //  seq_pos[s]: the set of positions in sequence s
117121    std::vector<seq_cpl_t > seq_cpl; //  seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
118122
Original file line number Diff line number Diff line change @@ -140,7 +140,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
140140
141141        std::vector<llama_ubatch> ubatches;
142142        while  (true ) {
143-             auto  ubatch = balloc.split_equal (n_ubatch);
143+             auto  ubatch = balloc.split_equal (n_ubatch,  false );
144144
145145            if  (ubatch.n_tokens  == 0 ) {
146146                break ;
Original file line number Diff line number Diff line change @@ -70,7 +70,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
7070                //  if all tokens are output, split by sequence
7171                ubatch = balloc.split_seq (n_ubatch);
7272            } else  {
73-                 ubatch = balloc.split_equal (n_ubatch);
73+                 ubatch = balloc.split_equal (n_ubatch,  false );
7474            }
7575
7676            if  (ubatch.n_tokens  == 0 ) {
Original file line number Diff line number Diff line change @@ -374,7 +374,7 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
374374                //  if all tokens are output, split by sequence
375375                ubatch = balloc.split_seq (n_ubatch);
376376            } else  {
377-                 ubatch = balloc.split_equal (n_ubatch);
377+                 ubatch = balloc.split_equal (n_ubatch,  false );
378378            }
379379
380380            if  (balloc.get_n_used () < balloc.get_n_tokens ()) {
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments