@@ -5179,6 +5179,57 @@ struct llama_model_loader {
51795179 }
51805180};
51815181
5182+ // temporary allocate memory for the input batch if needed
5183+ static const llama_seq_id batch_default_seq_id = 0;
5184+ struct llama_batch_allocr {
5185+ std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
5186+ std::vector<llama_pos> pos;
5187+ std::vector<int32_t> n_seq_id;
5188+ std::vector<llama_seq_id *> seq_id;
5189+ std::vector<int8_t> logits;
5190+ struct llama_batch batch;
5191+ // optionally fulfill the batch returned by llama_batch_get_one
5192+ llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
5193+ batch = in_batch;
5194+ GGML_ASSERT(batch.n_tokens > 0);
5195+ if (!batch.pos) {
5196+ // determine the last position in KV cache
5197+ llama_pos last_pos = -1;
5198+ for (const auto & cell : ctx.kv_self.cells) {
5199+ if (cell.has_seq_id(batch_default_seq_id)) {
5200+ last_pos = std::max(last_pos, cell.pos);
5201+ }
5202+ }
5203+ last_pos++; // next position
5204+ pos.resize(batch.n_tokens);
5205+ for (int32_t i = 0; i < batch.n_tokens; i++) {
5206+ pos[i] = i+last_pos;
5207+ }
5208+ batch.pos = pos.data();
5209+ }
5210+ if (!batch.n_seq_id) {
5211+ n_seq_id.resize(batch.n_tokens);
5212+ for (int32_t i = 0; i < batch.n_tokens; i++) {
5213+ n_seq_id[i] = seq_id_0.size();
5214+ }
5215+ batch.n_seq_id = n_seq_id.data();
5216+ }
5217+ if (!batch.seq_id) {
5218+ seq_id.resize(batch.n_tokens + 1);
5219+ seq_id[batch.n_tokens] = NULL;
5220+ for (int32_t i = 0; i < batch.n_tokens; i++) {
5221+ seq_id[i] = seq_id_0.data();
5222+ }
5223+ batch.seq_id = seq_id.data();
5224+ }
5225+ if (!batch.logits) {
5226+ logits.resize(batch.n_tokens);
5227+ logits[logits.size() - 1] = true;
5228+ batch.logits = logits.data();
5229+ }
5230+ }
5231+ };
5232+
51825233template<>
51835234bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
51845235 uint32_t tmp;
@@ -17101,16 +17152,20 @@ static void llama_graph_compute(
1710117152//
1710217153static int llama_decode_internal(
1710317154 llama_context & lctx,
17104- llama_batch batch ) {
17155+ llama_batch inp_batch ) {
1710517156
1710617157 lctx.is_encoding = false;
17107- const uint32_t n_tokens_all = batch.n_tokens;
1710817158
17109- if (n_tokens_all == 0) {
17159+ if (inp_batch.n_tokens == 0) {
1711017160 LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
1711117161 return -1;
1711217162 }
1711317163
17164+ // temporary allocate memory for the input batch if needed
17165+ llama_batch_allocr batch_allocr(lctx, inp_batch);
17166+ const llama_batch & batch = batch_allocr.batch;
17167+ const uint32_t n_tokens_all = batch.n_tokens;
17168+
1711417169 const auto & model = lctx.model;
1711517170 const auto & hparams = model.hparams;
1711617171 const auto & cparams = lctx.cparams;
@@ -17415,17 +17470,20 @@ static int llama_decode_internal(
1741517470//
1741617471static int llama_encode_internal(
1741717472 llama_context & lctx,
17418- llama_batch batch ) {
17473+ llama_batch inp_batch ) {
1741917474
1742017475 lctx.is_encoding = true;
1742117476
17422- const uint32_t n_tokens = batch.n_tokens;
17423-
17424- if (n_tokens == 0) {
17477+ if (inp_batch.n_tokens == 0) {
1742517478 LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
1742617479 return -1;
1742717480 }
1742817481
17482+ // temporary allocate memory for the input batch if needed
17483+ llama_batch_allocr batch_allocr(lctx, inp_batch);
17484+ const llama_batch & batch = batch_allocr.batch;
17485+ const uint32_t n_tokens = batch.n_tokens;
17486+
1742917487 const auto & model = lctx.model;
1743017488 const auto & hparams = model.hparams;
1743117489 const auto & cparams = lctx.cparams;
@@ -21096,61 +21154,10 @@ void llama_batch_free(struct llama_batch batch) {
2109621154 if (batch.logits) free(batch.logits);
2109721155}
2109821156
21099- // temporary allocate memory for the input batch if needed
21100- static const llama_seq_id batch_default_seq_id = 0;
21101- struct llama_batch_allocr {
21102- std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
21103- std::vector<llama_pos> pos;
21104- std::vector<int32_t> n_seq_id;
21105- std::vector<llama_seq_id *> seq_id;
21106- std::vector<int8_t> logits;
21107- struct llama_batch batch;
21108- // optionally fulfill the batch returned by llama_batch_get_one
21109- llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) {
21110- batch = in_batch;
21111- if (!batch.pos) {
21112- // determine the last position in KV cache
21113- llama_pos last_pos = -1;
21114- for (const auto & cell : ctx->kv_self.cells) {
21115- if (cell.has_seq_id(batch_default_seq_id)) {
21116- last_pos = std::max(last_pos, cell.pos);
21117- }
21118- }
21119- last_pos++; // next position
21120- pos.resize(batch.n_tokens);
21121- for (int32_t i = 0; i < batch.n_tokens; i++) {
21122- pos[i] = i+last_pos;
21123- }
21124- batch.pos = pos.data();
21125- }
21126- if (!batch.n_seq_id) {
21127- n_seq_id.resize(batch.n_tokens);
21128- for (int32_t i = 0; i < batch.n_tokens; i++) {
21129- n_seq_id[i] = seq_id_0.size();
21130- }
21131- batch.n_seq_id = n_seq_id.data();
21132- }
21133- if (!batch.seq_id) {
21134- seq_id.resize(batch.n_tokens + 1);
21135- seq_id[batch.n_tokens] = NULL;
21136- for (int32_t i = 0; i < batch.n_tokens; i++) {
21137- seq_id[i] = seq_id_0.data();
21138- }
21139- batch.seq_id = seq_id.data();
21140- }
21141- if (!batch.logits) {
21142- logits.resize(batch.n_tokens);
21143- logits[logits.size() - 1] = true;
21144- batch.logits = logits.data();
21145- }
21146- }
21147- };
21148-
2114921157int32_t llama_encode(
2115021158 struct llama_context * ctx,
2115121159 struct llama_batch batch) {
21152- llama_batch_allocr batch_allocr(ctx, batch);
21153- const int ret = llama_encode_internal(*ctx, batch_allocr.batch);
21160+ const int ret = llama_encode_internal(*ctx, batch);
2115421161 if (ret != 0) {
2115521162 LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
2115621163 }
@@ -21161,8 +21168,7 @@ int32_t llama_encode(
2116121168int32_t llama_decode(
2116221169 struct llama_context * ctx,
2116321170 struct llama_batch batch) {
21164- llama_batch_allocr batch_allocr(ctx, batch);
21165- const int ret = llama_decode_internal(*ctx, batch_allocr.batch);
21171+ const int ret = llama_decode_internal(*ctx, batch);
2116621172 if (ret != 0) {
2116721173 LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
2116821174 }
0 commit comments