@@ -728,7 +728,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     }
 
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -894,7 +894,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return -1;
     }
 
-    if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
+    // when computing embeddings, all tokens are output
+    const bool embd_all = cparams.embeddings;
+
+    if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -911,12 +914,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    // this indicates we are doing pooled embedding
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
     const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
 
-    if (embd_pooled) {
+    if (embd_all) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -945,7 +945,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
         if (!mstate) {
             return -2;
         }
@@ -1058,7 +1058,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
         //}
 
-        auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
+        auto * t_logits = res->get_logits();
         auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
 
         if (t_embd && res->get_embd_pooled()) {
@@ -1222,9 +1222,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto n_vocab = vocab.n_tokens();
     const auto n_embd = hparams.n_embd;
 
-    // TODO: use a per-batch flag for logits presence instead
-    bool has_logits = !cparams.embeddings;
-    bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    bool has_logits = true;
+    bool has_embd = cparams.embeddings;
 
     // TODO: hacky enc-dec support
     if (model.arch == LLM_ARCH_T5) {
@@ -2044,14 +2043,11 @@ void llama_context::opt_epoch_iter(
 
         n_queued_tokens += n_tokens_all;
 
-        // this indicates we are doing pooled embedding
-        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
         embd_seq.clear();
 
         uint32_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
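
A minimal standalone sketch of the output-selection rule these hunks converge on. The helper name `output_all_tokens` and the `call_site` enum are hypothetical, not part of the llama.cpp API; the rule itself mirrors the boolean passed to `batch_allocr->init()` / `memory->init_batch()` in the diff: `encode()` and `opt_epoch_iter()` always request outputs for every token, while `decode()` does so only when `cparams.embeddings` is set.

```cpp
// Sketch only: illustrates when all tokens of a batch are marked as outputs,
// matching the values passed at the three call sites in the diff above.
#include <cstdio>

enum class call_site { encode, decode, opt_epoch_iter };

static bool output_all_tokens(call_site site, bool embeddings_enabled) {
    switch (site) {
        case call_site::encode:         return true;               // full sequence from pos = 0
        case call_site::opt_epoch_iter: return true;               // training: every token is an output
        case call_site::decode:         return embeddings_enabled; // corresponds to cparams.embeddings
    }
    return false;
}

int main() {
    std::printf("decode, embeddings off -> %d\n", output_all_tokens(call_site::decode, false));
    std::printf("decode, embeddings on  -> %d\n", output_all_tokens(call_site::decode, true));
    std::printf("encode                 -> %d\n", output_all_tokens(call_site::encode, false));
    return 0;
}
```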