@@ -2082,7 +2082,7 @@ struct llama_context {
     struct ggml_tensor * inp_mean;   // F32 [n_batch, n_batch]
     struct ggml_tensor * inp_cls;    // I32 [n_batch]
     struct ggml_tensor * inp_s_copy; // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask; // F32 [kv_size]
+    struct ggml_tensor * inp_s_mask; // F32 [1, kv_size]
     struct ggml_tensor * inp_s_seq;  // I32 [kv_size, n_batch]
 
 #ifdef GGML_USE_MPI
@@ -5518,6 +5518,9 @@ struct llm_build_context {
         lctx.inp_K_shift = nullptr;
         lctx.inp_mean    = nullptr;
         lctx.inp_cls     = nullptr;
+        lctx.inp_s_copy  = nullptr;
+        lctx.inp_s_mask  = nullptr;
+        lctx.inp_s_seq   = nullptr;
     }
 
     void free() {
@@ -5559,14 +5562,14 @@ struct llm_build_context {
 
         GGML_ASSERT(kv_self.recurrent);
 
-        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size);
+        struct ggml_tensor * state_copy = build_inp_s_copy();
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
             struct ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
 
-            conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
-            ssm_states  = ggml_get_rows(ctx0,  ssm_states, lctx.inp_s_copy);
+            conv_states = ggml_get_rows(ctx0, conv_states, state_copy);
+            ssm_states  = ggml_get_rows(ctx0,  ssm_states, state_copy);
 
             // TODO: name the intermediate tensors with cb()
 
@@ -5665,6 +5668,27 @@ struct llm_build_context {
         return lctx.inp_cls;
     }
 
+    struct ggml_tensor * build_inp_s_copy() {
+        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size);
+        cb(lctx.inp_s_copy, "inp_s_copy", -1);
+        ggml_set_input(lctx.inp_s_copy);
+        return lctx.inp_s_copy;
+    }
+
+    struct ggml_tensor * build_inp_s_mask() {
+        lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
+        cb(lctx.inp_s_mask, "inp_s_mask", -1);
+        ggml_set_input(lctx.inp_s_mask);
+        return lctx.inp_s_mask;
+    }
+
+    struct ggml_tensor * build_inp_s_seq() {
+        lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
+        cb(lctx.inp_s_seq, "inp_s_seq", -1);
+        ggml_set_input(lctx.inp_s_seq);
+        return lctx.inp_s_seq;
+    }
+
     struct ggml_cgraph * build_llama() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
@@ -8148,12 +8172,8 @@ struct llm_build_context {
         // {n_embd, n_tokens}
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
-        struct ggml_tensor * state_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-        struct ggml_tensor * state_seq  = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens);
-        lctx.inp_s_mask = state_mask;
-        lctx.inp_s_seq  = state_seq;
-        ggml_set_input(state_mask);
-        ggml_set_input(state_seq);
+        struct ggml_tensor * state_mask = build_inp_s_mask();
+        struct ggml_tensor * state_seq  = build_inp_s_seq();
 
         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the KV cache to store the states
@@ -8205,7 +8225,7 @@ struct llm_build_context {
             ggml_build_forward_expand(gf,
                 ggml_cpy(ctx0,
                     ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)),
-                    ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_self.head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));
+                    ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));
 
             // extract x from x_conv
             x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0);
@@ -8239,7 +8259,7 @@ struct llm_build_context {
             ggml_build_forward_expand(gf,
                 ggml_cpy(ctx0,
                     ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)),
-                    ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_self.head*d_state*d_inner*ggml_element_size(ssm_states))));
+                    ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_head*d_state*d_inner*ggml_element_size(ssm_states))));
 
             struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);
 
@@ -8508,7 +8528,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
     }
 
-    if (batch.pos) {
+    if (batch.pos && lctx.inp_pos) {
         const int64_t n_tokens = batch.n_tokens;
 
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
@@ -8519,61 +8539,63 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         "non-causal attention with generative models is not supported"
     );
 
-    // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-    if (cparams.causal_attn) {
-        const int64_t n_kv     = kv_self.n;
-        const int64_t n_tokens = batch.n_tokens;
+    if (lctx.inp_KQ_mask) {
+        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
+        if (cparams.causal_attn) {
+            const int64_t n_kv     = kv_self.n;
+            const int64_t n_tokens = batch.n_tokens;
 
-        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
-        float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data = (float *) lctx.inp_KQ_mask->data;
 
-        // For causal attention, use only the previous KV cells
-        // of the correct sequence for each token of the batch.
-        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                const llama_pos    pos    = batch.pos[j];
-                const llama_seq_id seq_id = batch.seq_id[j][0];
+            // For causal attention, use only the previous KV cells
+            // of the correct sequence for each token of the batch.
+            // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+            for (int h = 0; h < 1; ++h) {
+                for (int j = 0; j < n_tokens; ++j) {
+                    const llama_pos    pos    = batch.pos[j];
+                    const llama_seq_id seq_id = batch.seq_id[j][0];
 
-                for (int i = 0; i < n_kv; ++i) {
-                    float f;
-                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
-                        f = -INFINITY;
-                    } else {
-                        f = 0.0f;
+                    for (int i = 0; i < n_kv; ++i) {
+                        float f;
+                        if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+                            f = -INFINITY;
+                        } else {
+                            f = 0.0f;
+                        }
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                     }
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                 }
             }
-        }
-    } else {
-        // when using kv cache, the mask needs to match the kv cache size
-        const int64_t n_tokens = batch.n_tokens;
-        const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
+        } else {
+            // when using kv cache, the mask needs to match the kv cache size
+            const int64_t n_tokens = batch.n_tokens;
+            const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
 
-        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
-        float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data = (float *) lctx.inp_KQ_mask->data;
 
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[j][0];
+            for (int h = 0; h < 1; ++h) {
+                for (int j = 0; j < n_tokens; ++j) {
+                    const llama_seq_id seq_id = batch.seq_id[j][0];
 
-                for (int i = 0; i < n_tokens; ++i) {
-                    float f = -INFINITY;
-                    for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                        if (batch.seq_id[i][s] == seq_id) {
-                            f = 0.0f;
-                            break;
+                    for (int i = 0; i < n_tokens; ++i) {
+                        float f = -INFINITY;
+                        for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+                            if (batch.seq_id[i][s] == seq_id) {
+                                f = 0.0f;
+                                break;
+                            }
                         }
-                    }
 
-                    data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
-                }
+                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
+                    }
 
-                for (int i = n_tokens; i < n_stride; ++i) {
-                    data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
+                    for (int i = n_tokens; i < n_stride; ++i) {
+                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
+                    }
                 }
             }
         }
@@ -8582,7 +8604,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     if (hparams.need_kq_pos) {
         const int64_t n_kv = kv_self.n;
 
-        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer));
+        GGML_ASSERT(lctx.inp_KQ_pos);
+        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer));
 
         float * data = (float *) lctx.inp_KQ_pos->data;
 
@@ -8594,6 +8617,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;
 
+        GGML_ASSERT(lctx.inp_mean);
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
 
         float * data = (float *) lctx.inp_mean->data;
@@ -8625,6 +8649,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
        const int64_t n_tokens = batch.n_tokens;
 
+        GGML_ASSERT(lctx.inp_cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
 
         uint32_t * data = (uint32_t *) lctx.inp_cls->data;
@@ -8645,7 +8670,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     if (kv_self.recurrent) {
         const int64_t n_kv = kv_self.n;
 
-        {
+        if (lctx.inp_s_mask) {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
             float * data = (float *) lctx.inp_s_mask->data;
 
@@ -8667,7 +8692,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         // update the correct state(s)/sequence(s) for each token of the batch.
         // Like with the KQ_mask, if a token in the batch has multiple sequences,
         // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv).
-        {
+        if (lctx.inp_s_seq) {
             const int64_t n_tokens = batch.n_tokens;
 
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer));
@@ -9272,11 +9297,15 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     }
 
     if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
-        llama_set_s_copy(lctx);
-
         {
+            ggml_backend_sched_reset(lctx.sched);
+
             ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
 
+            ggml_backend_sched_alloc_graph(lctx.sched, gf);
+
+            llama_set_s_copy(lctx);
+
             llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
 
             need_reserve = true;