Commit 937966d

llama : fix Mamba inference for pipeline parallelism
Tested to work correctly with both `main` and `parallel` examples.
1 parent 4ddccc2 commit 937966d
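
In outline, the change stops assuming that every input tensor exists in every compute graph: the Mamba state inputs (inp_s_copy, inp_s_mask, inp_s_seq) are now null-initialized, created on demand by the new build_inp_s_* helpers, and written only when they were actually built, which matters once the graph is split across pipeline stages. A minimal standalone sketch of that guard pattern, using hypothetical names rather than the real llama.cpp types:

#include <cstdio>

// stand-in for ggml_tensor; only the name matters for this sketch
struct tensor { const char * name; };

struct context {
    // optional inputs: stay nullptr when the current graph does not use them
    tensor * inp_s_mask = nullptr;
    tensor * inp_s_seq  = nullptr;
};

// write only the inputs that the graph builder actually created
static void set_inputs(context & ctx) {
    if (ctx.inp_s_mask) { std::printf("writing %s\n", ctx.inp_s_mask->name); }
    if (ctx.inp_s_seq)  { std::printf("writing %s\n", ctx.inp_s_seq->name); }
}

int main() {
    context ctx;
    set_inputs(ctx);        // graph without recurrent state: nothing to write

    tensor mask = { "inp_s_mask" };
    ctx.inp_s_mask = &mask; // graph that built the state mask input
    set_inputs(ctx);        // writes only the tensor that exists
    return 0;
}

The hunks below follow the same idea in llama_set_inputs(): inputs that may be absent from the current graph (inp_pos, inp_KQ_mask, inp_s_mask, inp_s_seq) are guarded by null checks, while inputs that must exist on that code path are checked with GGML_ASSERT.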

File tree

1 file changed, +88 -59 lines changed

llama.cpp

Lines changed: 88 additions & 59 deletions
@@ -2082,7 +2082,7 @@ struct llama_context {
     struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
     struct ggml_tensor * inp_cls; // I32 [n_batch]
     struct ggml_tensor * inp_s_copy; // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask; // F32 [kv_size]
+    struct ggml_tensor * inp_s_mask; // F32 [1, kv_size]
     struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch]

 #ifdef GGML_USE_MPI
@@ -5518,6 +5518,9 @@ struct llm_build_context {
         lctx.inp_K_shift = nullptr;
         lctx.inp_mean = nullptr;
         lctx.inp_cls = nullptr;
+        lctx.inp_s_copy = nullptr;
+        lctx.inp_s_mask = nullptr;
+        lctx.inp_s_seq = nullptr;
     }

     void free() {
@@ -5559,14 +5562,14 @@ struct llm_build_context {

         GGML_ASSERT(kv_self.recurrent);

-        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size);
+        struct ggml_tensor * state_copy = build_inp_s_copy();

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
             struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);

-            conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
-            ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_s_copy);
+            conv_states = ggml_get_rows(ctx0, conv_states, state_copy);
+            ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy);

             // TODO: name the intermediate tensors with cb()

@@ -5665,6 +5668,27 @@ struct llm_build_context {
         return lctx.inp_cls;
     }

+    struct ggml_tensor * build_inp_s_copy() {
+        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size);
+        cb(lctx.inp_s_copy, "inp_s_copy", -1);
+        ggml_set_input(lctx.inp_s_copy);
+        return lctx.inp_s_copy;
+    }
+
+    struct ggml_tensor * build_inp_s_mask() {
+        lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
+        cb(lctx.inp_s_mask, "inp_s_mask", -1);
+        ggml_set_input(lctx.inp_s_mask);
+        return lctx.inp_s_mask;
+    }
+
+    struct ggml_tensor * build_inp_s_seq() {
+        lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
+        cb(lctx.inp_s_seq, "inp_s_seq", -1);
+        ggml_set_input(lctx.inp_s_seq);
+        return lctx.inp_s_seq;
+    }
+
     struct ggml_cgraph * build_llama() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

@@ -8148,12 +8172,8 @@ struct llm_build_context {
         // {n_embd, n_tokens}
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

-        struct ggml_tensor * state_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-        struct ggml_tensor * state_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens);
-        lctx.inp_s_mask = state_mask;
-        lctx.inp_s_seq = state_seq;
-        ggml_set_input(state_mask);
-        ggml_set_input(state_seq);
+        struct ggml_tensor * state_mask = build_inp_s_mask();
+        struct ggml_tensor * state_seq = build_inp_s_seq();

         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the KV cache to store the states
@@ -8205,7 +8225,7 @@ struct llm_build_context {
             ggml_build_forward_expand(gf,
                 ggml_cpy(ctx0,
                     ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)),
-                    ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_self.head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));
+                    ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));

             // extract x from x_conv
             x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0);
@@ -8239,7 +8259,7 @@ struct llm_build_context {
             ggml_build_forward_expand(gf,
                 ggml_cpy(ctx0,
                     ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)),
-                    ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_self.head*d_state*d_inner*ggml_element_size(ssm_states))));
+                    ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_head*d_state*d_inner*ggml_element_size(ssm_states))));

             struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);

@@ -8508,7 +8528,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
     }

-    if (batch.pos) {
+    if (batch.pos && lctx.inp_pos) {
         const int64_t n_tokens = batch.n_tokens;

         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
@@ -8519,61 +8539,63 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         "non-causal attention with generative models is not supported"
     );

-    // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-    if (cparams.causal_attn) {
-        const int64_t n_kv = kv_self.n;
-        const int64_t n_tokens = batch.n_tokens;
+    if (lctx.inp_KQ_mask) {
+        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
+        if (cparams.causal_attn) {
+            const int64_t n_kv = kv_self.n;
+            const int64_t n_tokens = batch.n_tokens;

-        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));

-        float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data = (float *) lctx.inp_KQ_mask->data;

-        // For causal attention, use only the previous KV cells
-        // of the correct sequence for each token of the batch.
-        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                const llama_pos pos = batch.pos[j];
-                const llama_seq_id seq_id = batch.seq_id[j][0];
+            // For causal attention, use only the previous KV cells
+            // of the correct sequence for each token of the batch.
+            // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+            for (int h = 0; h < 1; ++h) {
+                for (int j = 0; j < n_tokens; ++j) {
+                    const llama_pos pos = batch.pos[j];
+                    const llama_seq_id seq_id = batch.seq_id[j][0];

-                for (int i = 0; i < n_kv; ++i) {
-                    float f;
-                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
-                        f = -INFINITY;
-                    } else {
-                        f = 0.0f;
+                    for (int i = 0; i < n_kv; ++i) {
+                        float f;
+                        if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+                            f = -INFINITY;
+                        } else {
+                            f = 0.0f;
+                        }
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                     }
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                 }
             }
-        }
-    } else {
-        // when using kv cache, the mask needs to match the kv cache size
-        const int64_t n_tokens = batch.n_tokens;
-        const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
+        } else {
+            // when using kv cache, the mask needs to match the kv cache size
+            const int64_t n_tokens = batch.n_tokens;
+            const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;

-        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));

-        float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data = (float *) lctx.inp_KQ_mask->data;

-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[j][0];
+            for (int h = 0; h < 1; ++h) {
+                for (int j = 0; j < n_tokens; ++j) {
+                    const llama_seq_id seq_id = batch.seq_id[j][0];

-                for (int i = 0; i < n_tokens; ++i) {
-                    float f = -INFINITY;
-                    for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                        if (batch.seq_id[i][s] == seq_id) {
-                            f = 0.0f;
-                            break;
+                    for (int i = 0; i < n_tokens; ++i) {
+                        float f = -INFINITY;
+                        for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+                            if (batch.seq_id[i][s] == seq_id) {
+                                f = 0.0f;
+                                break;
+                            }
                         }
-                    }

-                    data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
-                }
+                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
+                    }

-                for (int i = n_tokens; i < n_stride; ++i) {
-                    data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
+                    for (int i = n_tokens; i < n_stride; ++i) {
+                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
+                    }
                 }
             }
         }
@@ -8582,7 +8604,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     if (hparams.need_kq_pos) {
         const int64_t n_kv = kv_self.n;

-        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer));
+        GGML_ASSERT(lctx.inp_KQ_pos);
+        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer));

         float * data = (float *) lctx.inp_KQ_pos->data;

@@ -8594,6 +8617,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;

+        GGML_ASSERT(lctx.inp_mean);
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));

         float * data = (float *) lctx.inp_mean->data;
@@ -8625,6 +8649,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
         const int64_t n_tokens = batch.n_tokens;

+        GGML_ASSERT(lctx.inp_cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));

         uint32_t * data = (uint32_t *) lctx.inp_cls->data;
@@ -8645,7 +8670,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     if (kv_self.recurrent) {
         const int64_t n_kv = kv_self.n;

-        {
+        if (lctx.inp_s_mask) {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
             float * data = (float *) lctx.inp_s_mask->data;

@@ -8667,7 +8692,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         // update the correct state(s)/sequence(s) for each token of the batch.
         // Like with the KQ_mask, if a token in the batch has multiple sequences,
         // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv).
-        {
+        if (lctx.inp_s_seq) {
            const int64_t n_tokens = batch.n_tokens;

            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer));
@@ -9272,11 +9297,15 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     }

     if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
-        llama_set_s_copy(lctx);
-
         {
+            ggml_backend_sched_reset(lctx.sched);
+
             ggml_cgraph * gf = llama_build_graph_s_copy(lctx);

+            ggml_backend_sched_alloc_graph(lctx.sched, gf);
+
+            llama_set_s_copy(lctx);
+
             llama_graph_compute(lctx, gf, lctx.cparams.n_threads);

             need_reserve = true;
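
This last hunk also reorders the state-copy pass: the graph is now allocated by the backend scheduler before llama_set_s_copy() writes the inp_s_copy data, since with pipeline parallelism the input tensor has no backing buffer until ggml_backend_sched_alloc_graph() has run. A toy sketch of that "allocate, then write inputs" invariant, using stand-in types rather than the ggml API:

#include <cassert>
#include <cstdio>
#include <vector>

// toy stand-ins: a "tensor" whose buffer only exists after graph allocation
struct tensor { std::vector<float> buf; };
struct graph  { tensor * input = nullptr; };

// allocation happens when the graph is scheduled, not when the tensor is declared
static void sched_alloc_graph(graph & g, size_t n) {
    g.input->buf.resize(n);
}

// writing an input before allocation would be a bug
static void set_inputs(tensor & t, float v) {
    assert(!t.buf.empty());
    for (float & x : t.buf) { x = v; }
}

int main() {
    tensor state_copy;            // plays the role of lctx.inp_s_copy
    graph  gf;                    // plays the role of the s_copy graph
    gf.input = &state_copy;

    sched_alloc_graph(gf, 8);     // 1) allocate the graph and its input buffers
    set_inputs(state_copy, 1.0f); // 2) only then fill the input data
    std::printf("input[0] = %.1f\n", state_copy.buf[0]);
    return 0;
}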
