
Commit bbaff57

ngxson authored and ggerganov committed

llama : fix empty batch causing llama_batch_allocr to crash (ggml-org#9966)

* llama : fix empty batch causing llama_batch_allocr to crash
* move batch_allocr inside decode/encode_internal
* fix build
* add GGML_ASSERT
* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <[email protected]>

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 8e01f76 commit bbaff57
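
For context, a sketch of the failure mode this commit addresses (illustrative only, not part of the diff): before this change, the public llama_decode() built a llama_batch_allocr before the n_tokens == 0 check ever ran, so an empty batch ended up writing logits[logits.size() - 1] on an empty vector. After the change the empty batch is rejected first and -1 is returned. The snippet below assumes the two-argument llama_batch_get_one() present in this tree; model/context setup is omitted.

    // Hedged sketch: driving llama_decode with an empty batch.
    #include "llama.h"

    static int try_decode_empty_batch(struct llama_context * ctx) {
        // A batch with zero tokens, e.g. produced from an empty prompt.
        llama_batch batch = llama_batch_get_one(/*tokens=*/nullptr, /*n_tokens=*/0);

        // Previously this reached llama_batch_allocr and crashed inside its
        // constructor; now llama_decode_internal checks inp_batch.n_tokens == 0
        // before the allocator is created and returns -1 with an error log.
        return llama_decode(ctx, batch);
    }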

File tree

1 file changed: +67 -61 lines changed


src/llama.cpp

Lines changed: 67 additions & 61 deletions
@@ -5179,6 +5179,57 @@ struct llama_model_loader {
     }
 };
 
+// temporary allocate memory for the input batch if needed
+static const llama_seq_id batch_default_seq_id = 0;
+struct llama_batch_allocr {
+    std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id *> seq_id;
+    std::vector<int8_t>         logits;
+    struct llama_batch          batch;
+    // optionally fulfill the batch returned by llama_batch_get_one
+    llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
+        batch = in_batch;
+        GGML_ASSERT(batch.n_tokens > 0);
+        if (!batch.pos) {
+            // determine the last position in KV cache
+            llama_pos last_pos = -1;
+            for (const auto & cell : ctx.kv_self.cells) {
+                if (cell.has_seq_id(batch_default_seq_id)) {
+                    last_pos = std::max(last_pos, cell.pos);
+                }
+            }
+            last_pos++; // next position
+            pos.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                pos[i] = i+last_pos;
+            }
+            batch.pos = pos.data();
+        }
+        if (!batch.n_seq_id) {
+            n_seq_id.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                n_seq_id[i] = seq_id_0.size();
+            }
+            batch.n_seq_id = n_seq_id.data();
+        }
+        if (!batch.seq_id) {
+            seq_id.resize(batch.n_tokens + 1);
+            seq_id[batch.n_tokens] = NULL;
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                seq_id[i] = seq_id_0.data();
+            }
+            batch.seq_id = seq_id.data();
+        }
+        if (!batch.logits) {
+            logits.resize(batch.n_tokens);
+            logits[logits.size() - 1] = true;
+            batch.logits = logits.data();
+        }
+    }
+};
+
 template<>
 bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
     uint32_t tmp;
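
The struct added above backfills whichever optional fields the caller left null: positions continue from the last KV-cache position of the default sequence, every token is assigned to that single sequence, and only the final token requests logits. A standalone toy mirror of those defaults follows (hypothetical types and names, not code from llama.cpp), useful for seeing what a batch from llama_batch_get_one ends up looking like:

    // Toy sketch of the defaults applied by llama_batch_allocr.
    #include <cstdint>
    #include <vector>

    struct toy_defaults {
        std::vector<int32_t> pos;       // per-token position
        std::vector<int32_t> n_seq_id;  // sequences per token (always 1 here)
        std::vector<int8_t>  logits;    // request logits for this token?
    };

    static toy_defaults fill_defaults(int32_t n_tokens, int32_t last_pos_in_cache) {
        // The real allocator now asserts n_tokens > 0, which is what prevents
        // the out-of-bounds write on `logits` that this commit fixes.
        toy_defaults d;
        d.pos.resize(n_tokens);
        d.n_seq_id.assign(n_tokens, 1);
        d.logits.assign(n_tokens, 0);
        for (int32_t i = 0; i < n_tokens; i++) {
            d.pos[i] = last_pos_in_cache + 1 + i; // continue right after the cache
        }
        d.logits[n_tokens - 1] = 1;               // logits only for the last token
        return d;
    }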
@@ -17101,16 +17152,20 @@ static void llama_graph_compute(
 //
 static int llama_decode_internal(
          llama_context & lctx,
-           llama_batch   batch) {
+           llama_batch   inp_batch) {
 
     lctx.is_encoding = false;
-    const uint32_t n_tokens_all = batch.n_tokens;
 
-    if (n_tokens_all == 0) {
+    if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
+    // temporary allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(lctx, inp_batch);
+    const llama_batch & batch = batch_allocr.batch;
+    const uint32_t n_tokens_all = batch.n_tokens;
+
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
@@ -17415,17 +17470,20 @@ static int llama_decode_internal(
 //
 static int llama_encode_internal(
          llama_context & lctx,
-           llama_batch   batch) {
+           llama_batch   inp_batch) {
 
     lctx.is_encoding = true;
 
-    const uint32_t n_tokens = batch.n_tokens;
-
-    if (n_tokens == 0) {
+    if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
+    // temporary allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(lctx, inp_batch);
+    const llama_batch & batch = batch_allocr.batch;
+    const uint32_t n_tokens = batch.n_tokens;
+
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
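
Both internal entry points now follow the same order of operations: reject an empty batch first, then construct the allocator on the stack so the vectors it owns stay alive for the whole decode/encode. A compressed view of that shared flow (paraphrased from the two hunks above, not literal code from the diff):

    // Paraphrased control flow shared by llama_decode_internal and
    // llama_encode_internal after this commit (sketch).
    static int process_internal(llama_context & lctx, llama_batch inp_batch) {
        if (inp_batch.n_tokens == 0) {
            return -1;                                    // rejected before any allocation
        }
        llama_batch_allocr batch_allocr(lctx, inp_batch); // safe: n_tokens > 0 here
        const llama_batch & batch = batch_allocr.batch;   // fully-populated view
        // ... the rest of the function works only with `batch` ...
        (void) batch;
        return 0;
    }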
@@ -21096,61 +21154,10 @@ void llama_batch_free(struct llama_batch batch) {
     if (batch.logits)   free(batch.logits);
 }
 
-// temporary allocate memory for the input batch if needed
-static const llama_seq_id batch_default_seq_id = 0;
-struct llama_batch_allocr {
-    std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
-    std::vector<llama_pos>      pos;
-    std::vector<int32_t>        n_seq_id;
-    std::vector<llama_seq_id *> seq_id;
-    std::vector<int8_t>         logits;
-    struct llama_batch          batch;
-    // optionally fulfill the batch returned by llama_batch_get_one
-    llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) {
-        batch = in_batch;
-        if (!batch.pos) {
-            // determine the last position in KV cache
-            llama_pos last_pos = -1;
-            for (const auto & cell : ctx->kv_self.cells) {
-                if (cell.has_seq_id(batch_default_seq_id)) {
-                    last_pos = std::max(last_pos, cell.pos);
-                }
-            }
-            last_pos++; // next position
-            pos.resize(batch.n_tokens);
-            for (int32_t i = 0; i < batch.n_tokens; i++) {
-                pos[i] = i+last_pos;
-            }
-            batch.pos = pos.data();
-        }
-        if (!batch.n_seq_id) {
-            n_seq_id.resize(batch.n_tokens);
-            for (int32_t i = 0; i < batch.n_tokens; i++) {
-                n_seq_id[i] = seq_id_0.size();
-            }
-            batch.n_seq_id = n_seq_id.data();
-        }
-        if (!batch.seq_id) {
-            seq_id.resize(batch.n_tokens + 1);
-            seq_id[batch.n_tokens] = NULL;
-            for (int32_t i = 0; i < batch.n_tokens; i++) {
-                seq_id[i] = seq_id_0.data();
-            }
-            batch.seq_id = seq_id.data();
-        }
-        if (!batch.logits) {
-            logits.resize(batch.n_tokens);
-            logits[logits.size() - 1] = true;
-            batch.logits = logits.data();
-        }
-    }
-};
-
 int32_t llama_encode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    llama_batch_allocr batch_allocr(ctx, batch);
-    const int ret = llama_encode_internal(*ctx, batch_allocr.batch);
+    const int ret = llama_encode_internal(*ctx, batch);
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
     }
@@ -21161,8 +21168,7 @@ int32_t llama_encode(
 int32_t llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    llama_batch_allocr batch_allocr(ctx, batch);
-    const int ret = llama_decode_internal(*ctx, batch_allocr.batch);
+    const int ret = llama_decode_internal(*ctx, batch);
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
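
From the caller's side nothing changes: the public llama_encode/llama_decode wrappers simply forward the batch, and the defaults are now applied inside the internal functions. A typical caller that relies on those defaults (hedged sketch; model loading, tokenization, and sampling omitted, and `tokens` is assumed to hold the prompt):

    // Usage sketch: decoding a prompt with a batch from llama_batch_get_one,
    // leaving pos/n_seq_id/seq_id/logits for llama_batch_allocr to fill in.
    #include "llama.h"
    #include <vector>

    static bool decode_prompt(struct llama_context * ctx, std::vector<llama_token> & tokens) {
        if (tokens.empty()) {
            return false; // llama_decode would also reject this with -1 now
        }
        llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());
        return llama_decode(ctx, batch) == 0; // positions/seq ids/logits defaulted internally
    }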
