Commit 39579d3

mamba : move state_seq and state_mask views outside layer loop
A few tensors were also missing `struct` in front of `ggml_tensor`.
1 parent 3e5685f
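
Both views moved by this commit are loop-invariant: they are built only from batch-level inputs (`lctx.inp_s_mask`, `lctx.inp_s_seq`, `n_kv`, `n_tokens`) and never from the layer index `il`, so creating them once per graph instead of once per layer keeps `n_layer - 1` redundant view nodes out of the compute graph. A minimal sketch of the pattern, with hypothetical names (`ctx`, `inp_mask`, and `use_states` stand in for the real members of `llm_build_context`):

    // Hoisted: the view is identical for every layer, so build one node...
    struct ggml_tensor * state_mask = ggml_view_2d(ctx, inp_mask, 1, n_kv, inp_mask->nb[0], 0);
    for (int il = 0; il < n_layer; ++il) {
        use_states(il, state_mask); // ...and share it across all layers (hypothetical consumer)
    }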

1 file changed: +9 −7

llama.cpp

Lines changed: 9 additions & 7 deletions
@@ -5540,9 +5540,11 @@ struct llm_build_context {
     struct ggml_cgraph * build_s_copy() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
+        GGML_ASSERT(kv_self.recurrent);
+
         for (int il = 0; il < n_layer; ++il) {
-            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
-            ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
+            struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+            struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
 
             conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
             ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_s_copy);
@@ -8171,14 +8173,16 @@ struct llm_build_context {
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
+        struct ggml_tensor * state_mask = ggml_view_2d(ctx0, lctx.inp_s_mask, 1, n_kv, lctx.inp_s_mask->nb[0], 0);
+        struct ggml_tensor * state_seq = ggml_view_2d(ctx0, lctx.inp_s_seq, n_kv, n_tokens, n_kv*ggml_element_size(lctx.inp_s_seq), 0);
+
         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the KV cache to store the states
-            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
-            ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
+            struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+            struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
 
             // clear states of sequences which are starting at the beginning of this batch
             {
-                ggml_tensor * state_mask = ggml_view_2d(ctx0, lctx.inp_s_mask, 1, n_kv, lctx.inp_s_mask->nb[0], 0);
                 conv_states = ggml_mul(ctx0,
                     ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]),
                     state_mask);
@@ -8203,8 +8207,6 @@ struct llm_build_context {
             struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
             struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
 
-            struct ggml_tensor * state_seq = ggml_view_2d(ctx0, lctx.inp_s_seq, n_kv, n_tokens, n_kv*ggml_element_size(lctx.inp_s_seq), 0);
-
             // conv
             {
                 // Custom operator which is needed only to ease simultaneous sequence processing.
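
The `struct` additions visible in the hunks above are a consistency fix: `ggml_tensor` is declared in ggml.h, a C header that does not typedef its struct tags, so the elaborated `struct ggml_tensor` spelling is the one that works from C, and llama.cpp keeps that form even though it compiles as C++. A two-line illustration (standalone sketch, not part of the diff; `ctx` is assumed to be a valid `struct ggml_context *`):

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8); // valid in both C and C++
    ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);        // C++ only: C has no implicit typedef for struct tags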
