@@ -3253,7 +3253,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
32533253 const uint32_t v_dim = head_dim;
32543254 const int64_t num_attention_heads = hparams.n_head();
32553255 const int64_t q_num_heads = num_attention_heads;
3256- const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
3256+ const int64_t dt_dim = hparams.ssm_dt_rank > 0
3257+ ? hparams.ssm_dt_rank
3258+ : std::max<int64_t>(64, hparams.n_embd / 16);
32573259
32583260 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
32593261
@@ -17263,7 +17265,9 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
1726317265 cb(x_bcdt, "mamba_bcdt_proj", il);
1726417266
1726517267 // split into dt, B, C
17266- const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
17268+ const int64_t dt_dim = hparams.ssm_dt_rank > 0
17269+ ? hparams.ssm_dt_rank
17270+ : std::max<int64_t>(64, hparams.n_embd / 16);
1726717271 ggml_tensor * B = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0);
1726817272 ggml_tensor * C = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*d_state);
1726917273 ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*(2*d_state));
0 commit comments