
Commit b3304da

SSM: respect ssm_dt_rank for dt_dim when provided
Use the GGUF-provided time_step_rank (ssm_dt_rank) for dt_dim when it is > 0; otherwise fall back to max(64, n_embd/16).
1 parent ab53234 commit b3304da
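
For reference, a minimal, self-contained sketch of the selection rule described above. The hparams_sketch struct, the pick_dt_dim helper and the sample values are hypothetical illustrations, not llama.cpp code:

// Sketch of the dt_dim selection rule: prefer the GGUF-provided
// time_step_rank (ssm_dt_rank), otherwise fall back to max(64, n_embd/16).
// hparams_sketch and the values in main() are made up for illustration.
#include <algorithm>
#include <cstdint>
#include <cstdio>

struct hparams_sketch {
    uint32_t ssm_dt_rank; // GGUF time_step_rank; 0 means the key was not provided
    uint32_t n_embd;      // model embedding width
};

static int64_t pick_dt_dim(const hparams_sketch & hp) {
    return hp.ssm_dt_rank > 0
        ? (int64_t) hp.ssm_dt_rank
        : std::max<int64_t>(64, hp.n_embd / 16);
}

int main() {
    const hparams_sketch with_rank    = { 48, 2048 }; // hypothetical model metadata
    const hparams_sketch without_rank = {  0, 2048 };

    std::printf("dt_dim with ssm_dt_rank=48: %lld\n", (long long) pick_dt_dim(with_rank));    // 48
    std::printf("dt_dim fallback (n_embd=2048): %lld\n", (long long) pick_dt_dim(without_rank)); // 128
    return 0;
}

Models whose GGUF carries time_step_rank take the first branch; models without the key keep the previous max(64, n_embd/16) behaviour.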

File tree: 1 file changed (+6 −2 lines)

src/llama-model.cpp

@@ -3253,7 +3253,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     const uint32_t v_dim = head_dim;
                     const int64_t num_attention_heads = hparams.n_head();
                     const int64_t q_num_heads = num_attention_heads;
-                    const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
+                    const int64_t dt_dim = hparams.ssm_dt_rank > 0
+                        ? hparams.ssm_dt_rank
+                        : std::max<int64_t>(64, hparams.n_embd / 16);

                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

@@ -17263,7 +17265,9 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
         cb(x_bcdt, "mamba_bcdt_proj", il);

         // split into dt, B, C
-        const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
+        const int64_t dt_dim = hparams.ssm_dt_rank > 0
+            ? hparams.ssm_dt_rank
+            : std::max<int64_t>(64, hparams.n_embd / 16);
         ggml_tensor * B = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0);
         ggml_tensor * C = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*d_state);
         ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*(2*d_state));
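
For context on the second hunk: x_bcdt fuses B (d_state), C (d_state) and dt (dt_dim) along its first dimension, and the three views are taken at offsets 0, d_state and 2*d_state, so dt_dim only changes the width of the last slice. Below is a rough, ggml-free sketch of that offset arithmetic, using plain element offsets instead of the byte offsets ggml_view_3d takes; all sizes are hypothetical:

// Sketch of slicing a fused [B | C | dt] row into its three parts.
// d_state and dt_dim are made-up values; real code works on ggml tensors
// and multiplies offsets by ggml_element_size(x_bcdt).
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int64_t d_state = 4;                 // hypothetical state size
    const int64_t dt_dim  = 3;                 // hypothetical dt rank
    const int64_t row     = 2 * d_state + dt_dim;

    std::vector<float> x_bcdt(row);
    for (int64_t i = 0; i < row; ++i) x_bcdt[i] = (float) i; // dummy data

    const float * B  = x_bcdt.data();               // offset 0
    const float * C  = x_bcdt.data() + d_state;     // offset d_state
    const float * dt = x_bcdt.data() + 2 * d_state; // offset 2*d_state, dt_dim wide

    std::printf("B[0]=%g  C[0]=%g  dt[0]=%g  (dt is %lld elements wide)\n",
                B[0], C[0], dt[0], (long long) dt_dim);
    return 0;
}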
