@@ -282,7 +282,7 @@ bool llama_batch_allocr::init(
282282        }
283283    }
284284
285-     //  disallow disjoint  sequence sets:
285+     //  disallow partial  sequence sub- sets:
286286    // 
287287    //  invalid:          x
288288    //             i: 0 1 2 ...
@@ -291,28 +291,46 @@ bool llama_batch_allocr::init(
291291    //  seq_id[i][1]: 1 1 2
292292    //  seq_id[i][2]: 2
293293    // 
294+     //  disallow decreasing sequence positions:
295+     // 
296+     //  invalid:                  x
297+     //             i: 0 1 2 3 4 5 6 ...
298+     //  ---------------------------------------
299+     //        pos[i]: 4 5 0 1 6 2 3
300+     //  seq_id[i][0]: 0 0 1 1 0 1 0
301+     // 
294302    {
295303        seq_set_t  cur_seq_set[LLAMA_MAX_SEQ];
296304        for  (int32_t  s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
297305            cur_seq_set[s].set ();
298306        }
299307
308+         llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
309+         for  (int32_t  s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
310+             cur_seq_pos[s] = -1 ;
311+         }
312+ 
300313        for  (int32_t  i = 0 ; i < batch.n_tokens ; ++i) {
314+             const  llama_pos pos = batch.pos [i];
315+ 
301316            for  (int32_t  s = 0 ; s < batch.n_seq_id [i]; ++s) {
302317                const  llama_seq_id seq_id = batch.seq_id [i][s];
303318
304319                cur_seq_set[seq_id] &= seq_set[i];
305320
306321                if  (cur_seq_set[seq_id].none ()) {
307-                     LLAMA_LOG_ERROR (" %s: sequence %d belongs to incompatible sequence sets\n " 
322+                     LLAMA_LOG_ERROR (" %s: sequence %d belongs to incompatible sequence sets (not allowed)\n " 
323+                     return  false ;
324+                 }
325+ 
326+                 if  (pos < cur_seq_pos[seq_id]) {
327+                     LLAMA_LOG_ERROR (" %s: sequence %d positions are decreasing (not allowed)\n " 
308328                    return  false ;
309329                }
310330            }
311331        }
312332    }
313333
314-     //  TODO: check that positions are increasing
315- 
316334    split_reset ();
317335
318336    return  true ;
0 commit comments