@@ -363,30 +363,35 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
363363}
364364
365365llama_memory_context_ptr llama_memory_recurrent::init_batch (llama_batch_allocr & balloc, uint32_t  n_ubatch, bool  embd_all) {
366-     std::vector<llama_ubatch> ubatches;
366+     do  {
367+         balloc.split_reset ();
367368
368-     while  (true ) {
369-         llama_ubatch ubatch;
369+         std::vector<llama_ubatch> ubatches;
370+         while  (true ) {
371+             llama_ubatch ubatch;
370372
371-         if  (embd_all) {
372-             //  if all tokens are output, split by sequence
373-             ubatch = balloc.split_seq (n_ubatch);
374-         } else  {
375-             ubatch = balloc.split_equal (n_ubatch);
373+             if  (embd_all) {
374+                 //  if all tokens are output, split by sequence
375+                 ubatch = balloc.split_seq (n_ubatch);
376+             } else  {
377+                 ubatch = balloc.split_equal (n_ubatch);
378+             }
379+ 
380+             if  (ubatch.n_tokens  == 0 ) {
381+                 break ;
382+             }
383+ 
384+             ubatches.push_back (std::move (ubatch)); //  NOLINT
376385        }
377386
378-         if  (ubatch. n_tokens  ==  0 ) {
387+         if  (! prepare (ubatches) ) {
379388            break ;
380389        }
381390
382-         ubatches.push_back (std::move (ubatch)); //  NOLINT
383-     }
384- 
385-     if  (!prepare (ubatches)) {
386-         return  std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
387-     }
391+         return  std::make_unique<llama_memory_recurrent_context>(this , std::move (ubatches));
392+     } while  (false );
388393
389-     return  std::make_unique<llama_memory_recurrent_context>(this ,  std::move (ubatches) );
394+     return  std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE );
390395}
391396
392397llama_memory_context_ptr llama_memory_recurrent::init_full () {
0 commit comments