@@ -5540,9 +5540,11 @@ struct llm_build_context {
     struct ggml_cgraph * build_s_copy() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
+        GGML_ASSERT(kv_self.recurrent);
+
         for (int il = 0; il < n_layer; ++il) {
-            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
-            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
+            struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+            struct ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
 
             conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
             ssm_states  = ggml_get_rows(ctx0, ssm_states,  lctx.inp_s_copy);
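Context for this hunk: `build_s_copy` builds a graph that rearranges the recurrent (Mamba) states stored in the KV cache, and the new `GGML_ASSERT(kv_self.recurrent)` makes that precondition explicit. The copy itself relies on `ggml_get_rows`, which gathers row `inp_s_copy[i]` of each reshaped state tensor into row `i` of the result. A minimal plain-C sketch of that gather semantics (sizes and helper names here are hypothetical, for illustration only):

```c
// Sketch of what the s_copy graph computes per layer: out[i] = states[copy[i]]
// for each cache cell i, i.e. the row-gather semantics of ggml_get_rows.
#include <stdio.h>
#include <string.h>

#define N_CELLS 4 // stands in for kv_self.size (hypothetical)
#define D_STATE 3 // stands in for hparams.n_embd_k_s() (hypothetical)

static void s_copy_rows(float src[N_CELLS][D_STATE],
                        const int copy[N_CELLS],
                        float dst[N_CELLS][D_STATE]) {
    for (int i = 0; i < N_CELLS; ++i) {
        // row i of the result is row copy[i] of the source states
        memcpy(dst[i], src[copy[i]], sizeof(float)*D_STATE);
    }
}

int main(void) {
    float states[N_CELLS][D_STATE] = {
        {0,0,0}, {1,1,1}, {2,2,2}, {3,3,3},
    };
    int   copy[N_CELLS] = {0, 0, 1, 3}; // e.g. cell 1 takes over cell 0's state
    float out[N_CELLS][D_STATE];

    s_copy_rows(states, copy, out);

    for (int i = 0; i < N_CELLS; ++i) {
        printf("cell %d: %g %g %g\n", i, out[i][0], out[i][1], out[i][2]);
    }
    return 0;
}
```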
@@ -8171,14 +8173,16 @@ struct llm_build_context {
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
+        struct ggml_tensor * state_mask = ggml_view_2d(ctx0, lctx.inp_s_mask, 1, n_kv, lctx.inp_s_mask->nb[0], 0);
+        struct ggml_tensor * state_seq  = ggml_view_2d(ctx0, lctx.inp_s_seq, n_kv, n_tokens, n_kv*ggml_element_size(lctx.inp_s_seq), 0);
+
         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the KV cache to store the states
-            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
-            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
+            struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+            struct ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
 
             // clear states of sequences which are starting at the beginning of this batch
             {
-                ggml_tensor * state_mask = ggml_view_2d(ctx0, lctx.inp_s_mask, 1, n_kv, lctx.inp_s_mask->nb[0], 0);
                 conv_states = ggml_mul(ctx0,
                     ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]),
                     state_mask);
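The substantive change in this hunk hoists the `state_mask` and `state_seq` views out of the per-layer loop: both are views of whole-batch inputs and do not depend on `il`, so building them once avoids creating `n_layer` redundant view nodes in the graph. The masking step then multiplies each cache cell's state row by a 0/1 entry (the `[1, n_kv]` mask broadcasts across the state dimension in `ggml_mul`), clearing the states of sequences that start at the beginning of the batch. A minimal plain-C sketch of that effect (sizes hypothetical):

```c
// Sketch of the state-clearing step: each cell's state row is scaled by a
// per-cell 0/1 mask, so cells whose sequence starts in this batch are zeroed.
#include <stdio.h>

#define N_KV   4 // number of cache cells in range (hypothetical)
#define D_CONV 3 // conv state row width (hypothetical)

static void mask_states(float states[N_KV][D_CONV], const float mask[N_KV]) {
    for (int i = 0; i < N_KV; ++i) {
        for (int j = 0; j < D_CONV; ++j) {
            states[i][j] *= mask[i]; // mask[i] == 0 clears cell i's state
        }
    }
}

int main(void) {
    float states[N_KV][D_CONV] = {
        {1,2,3}, {4,5,6}, {7,8,9}, {10,11,12},
    };
    float mask[N_KV] = {1, 0, 1, 1}; // cell 1 starts a new sequence

    mask_states(states, mask);

    for (int i = 0; i < N_KV; ++i) {
        printf("cell %d: %g %g %g\n", i, states[i][0], states[i][1], states[i][2]);
    }
    return 0;
}
```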
@@ -8203,8 +8207,6 @@ struct llm_build_context {
             struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
             struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
 
-            struct ggml_tensor * state_seq = ggml_view_2d(ctx0, lctx.inp_s_seq, n_kv, n_tokens, n_kv*ggml_element_size(lctx.inp_s_seq), 0);
-
             // conv
             {
                 // Custom operator which is needed only to ease simultaneous sequence processing.
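This last hunk only removes the per-layer `state_seq` view that the previous hunk now builds once before the loop. For context on the `// conv` block it touches: per channel, Mamba's convolution is a rolling causal one, where the state keeps the last few inputs and each new input shifts the window; the custom operator exists to apply this across several sequences in one batch. A single-channel plain-C sketch under that reading (kernel width hypothetical; the real operator handles all channels and sequences simultaneously):

```c
// Sketch of a rolling depthwise causal convolution for one channel: the
// state holds the last D_CONV-1 inputs, a new input is shifted in, and the
// output is the dot product of the current window with the conv weights.
#include <stdio.h>

#define D_CONV 4 // conv kernel width (hypothetical)

static float conv_step(float state[D_CONV - 1], const float w[D_CONV], float x) {
    float y = w[D_CONV - 1]*x;
    for (int k = 0; k < D_CONV - 1; ++k) {
        y += w[k]*state[k];
    }
    // shift the window: drop the oldest input, append the newest
    for (int k = 0; k < D_CONV - 2; ++k) {
        state[k] = state[k + 1];
    }
    state[D_CONV - 2] = x;
    return y;
}

int main(void) {
    float state[D_CONV - 1] = {0, 0, 0};
    float w[D_CONV] = {0.1f, 0.2f, 0.3f, 0.4f};

    for (int t = 1; t <= 5; ++t) {
        printf("t=%d y=%g\n", t, conv_step(state, w, (float) t));
    }
    return 0;
}
```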