@@ -401,6 +401,39 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
401401 return true ;
402402}
403403
404+ struct llava_embd_batch {
405+ std::vector<llama_pos> pos;
406+ std::vector<int32_t > n_seq_id;
407+ std::vector<llama_seq_id> seq_id_0;
408+ std::vector<llama_seq_id *> seq_ids;
409+ std::vector<int8_t > logits;
410+ llama_batch batch;
411+ llava_embd_batch (float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
412+ pos .resize (n_tokens);
413+ n_seq_id.resize (n_tokens);
414+ seq_ids .resize (n_tokens + 1 );
415+ logits .resize (n_tokens);
416+ seq_id_0.resize (1 );
417+ seq_id_0[0 ] = seq_id;
418+ seq_ids [n_tokens] = nullptr ;
419+ batch = {
420+ /* n_tokens =*/ n_tokens,
421+ /* tokens =*/ nullptr ,
422+ /* embd =*/ embd,
423+ /* pos =*/ pos.data (),
424+ /* n_seq_id =*/ n_seq_id.data (),
425+ /* seq_id =*/ seq_ids.data (),
426+ /* logits =*/ logits.data (),
427+ };
428+ for (int i = 0 ; i < n_tokens; i++) {
429+ batch.pos [i] = pos_0 + i;
430+ batch.n_seq_id [i] = 1 ;
431+ batch.seq_id [i] = seq_id_0.data ();
432+ batch.logits [i] = false ;
433+ }
434+ }
435+ };
436+
404437bool llava_eval_image_embed (llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
405438 int n_embd = llama_n_embd (llama_get_model (ctx_llama));
406439
@@ -409,8 +442,9 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
409442 if (n_eval > n_batch) {
410443 n_eval = n_batch;
411444 }
412- llama_batch batch = {int32_t (n_eval), nullptr , (image_embed->embed +i*n_embd), nullptr , nullptr , nullptr , nullptr , *n_past, 1 , 0 , };
413- if (llama_decode (ctx_llama, batch)) {
445+ float * embd = image_embed->embed +i*n_embd;
446+ llava_embd_batch llava_batch = llava_embd_batch (embd, n_eval, *n_past, 0 );
447+ if (llama_decode (ctx_llama, llava_batch.batch )) {
414448 LOG_ERR (" %s : failed to eval\n " , __func__);
415449 return false ;
416450 }
0 commit comments