diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 8f462396f4a..a1d4e6e5197 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -247,6 +247,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" }, { LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" }, { LLM_KV_ATTENTION_SHARED_KV_LAYERS, "%s.attention.shared_kv_layers" }, + { LLM_KV_ATTENTION_RECURRENT_LAYERS, "%s.attention.recurrent_layers" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index b47c05d90d5..3b80b2ae19f 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -251,6 +251,7 @@ enum llm_kv { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, LLM_KV_ATTENTION_INDEXER_TOP_K, LLM_KV_ATTENTION_SHARED_KV_LAYERS, + LLM_KV_ATTENTION_RECURRENT_LAYERS, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_COUNT_SWA, diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 2239309c8fb..087afec55c6 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -8,18 +8,31 @@ void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) { if (dense_first) { for (uint32_t il = 0; il < n_layer; ++il) { - swa_layers[il] = n_pattern == 0 || (il % n_pattern != 0); + is_swa_impl[il] = n_pattern == 0 || (il % n_pattern != 0); } } else { for (uint32_t il = 0; il < n_layer; ++il) { - swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); + is_swa_impl[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); } } } +// TODO: implement +//void llama_hparams::set_recr_pattern(uint32_t n_pattern, bool dense_first) { +// if (dense_first) { +// for (uint32_t il = 0; il < n_layer; ++il) { +// is_recr_impl[il] = n_pattern == 0 || (il % n_pattern != 0); +// } +// } else { +// for (uint32_t il = 0; il < n_layer; ++il) { +// is_recr_impl[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); +// } +// } +//} + bool llama_hparams::is_swa_any() const { for (uint32_t il = 0; il < n_layer; ++il) { - if (swa_layers[il]) { + if (is_swa_impl[il]) { return true; } } @@ -193,9 +206,9 @@ uint32_t llama_hparams::n_embd_s() const { return ssm_d_state * ssm_d_inner; } -bool llama_hparams::is_recurrent(uint32_t il) const { +bool llama_hparams::is_recr(uint32_t il) const { if (il < n_layer) { - return recurrent_layer_arr[il]; + return is_recr_impl[il]; } GGML_ABORT("%s: il (%u) out of bounds (n_layer: %u)\n", __func__, il, n_layer); @@ -207,7 +220,7 @@ uint32_t llama_hparams::n_pos_per_embd() const { bool llama_hparams::is_swa(uint32_t il) const { if (il < n_layer) { - return swa_layers[il]; + return is_swa_impl[il]; } GGML_ABORT("fatal error"); diff --git a/src/llama-hparams.h b/src/llama-hparams.h index e4601d30f51..e8ed4dd74de 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -37,6 +37,9 @@ struct llama_hparams_convnext { }; struct llama_hparams { + // note: use the `_impl` suffix to avoid name conflict between members and getters + // for example: n_embd_out() vs n_embd_out_impl + bool vocab_only; bool no_alloc; bool rope_finetuned; @@ -46,7 +49,7 @@ struct llama_hparams { uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; uint32_t n_layer; - int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache + int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache uint32_t n_expert = 0; uint32_t n_expert_used = 0; uint32_t n_rel_attn_bkts = 0; @@ -137,11 +140,15 @@ struct llama_hparams { llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; // the size of the sliding window (0 - no SWA) uint32_t n_swa = 0; - // if swa_layers[il] == 1, then layer il is SWA - // if swa_layers[il] == 0, then layer il is dense (i.e. non-SWA) + + // if is_swa_impl[il] == 1, then layer il is SWA + // if is_swa_impl[il] == 0, then layer il is dense (i.e. non-SWA) // by default, all layers are dense // note: using uint32_t type for compatibility reason - std::array swa_layers; + std::array is_swa_impl; + + // for hybrid state space models + std::array is_recr_impl; // for State Space Models uint32_t ssm_d_conv = 0; @@ -153,9 +160,6 @@ struct llama_hparams { // for Kimi Linear KDA uint32_t n_embd_head_kda = 0; - // for hybrid state space models - std::array recurrent_layer_arr; - bool ssm_dt_b_c_rms = false; float f_clamp_kqv = 0.0f; @@ -266,6 +270,14 @@ struct llama_hparams { // return true if one of the layers is SWA bool is_swa_any() const; + bool is_swa(uint32_t il) const; + + // TODO: implement + //void set_recr_pattern(uint32_t n_pattern, bool dense_first = false); + + // whether or not the given layer is recurrent (for hybrid models) + bool is_recr(uint32_t il) const; + uint32_t n_head(uint32_t il = 0) const; uint32_t n_head_kv(uint32_t il = 0) const; @@ -307,13 +319,8 @@ struct llama_hparams { // dimension of the recurrent state embeddings uint32_t n_embd_s() const; - // whether or not the given layer is recurrent (for hybrid models) - bool is_recurrent(uint32_t il) const; - uint32_t n_pos_per_embd() const; - bool is_swa(uint32_t il) const; - // note: currently only support if either all or none of the layers are MLA bool is_mla() const; diff --git a/src/llama-memory-hybrid-iswa.cpp b/src/llama-memory-hybrid-iswa.cpp index 72f5c2fea72..a242079b406 100644 --- a/src/llama-memory-hybrid-iswa.cpp +++ b/src/llama-memory-hybrid-iswa.cpp @@ -44,7 +44,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( n_ubatch, n_pad, filter_attn == nullptr ? - [&](int32_t il) { return !hparams.is_recurrent(il); } + [&](int32_t il) { return !hparams.is_recr(il); } : filter_attn, nullptr )), @@ -57,7 +57,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( n_seq_max, n_rs_seq, filter_recr == nullptr ? - [&](int32_t il) { return hparams.is_recurrent(il); } + [&](int32_t il) { return hparams.is_recr(il); } : filter_recr )) {} diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 6bd2ec18ce3..66ec3fd6d55 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -45,7 +45,7 @@ llama_memory_hybrid::llama_memory_hybrid( n_swa, swa_type, filter_attn == nullptr ? - [&](int32_t il) { return !hparams.is_recurrent(il); } + [&](int32_t il) { return !hparams.is_recr(il); } : filter_attn, nullptr )), @@ -58,7 +58,7 @@ llama_memory_hybrid::llama_memory_hybrid( n_seq_max, n_rs_seq, filter_recr == nullptr ? - [&](int32_t il) { return hparams.is_recurrent(il); } + [&](int32_t il) { return hparams.is_recr(il); } : filter_recr )) {} diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index c645d0785ab..4d7b11067c9 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -146,7 +146,7 @@ namespace GGUFMeta { const enum gguf_type arr_type = gguf_get_arr_type(ctx, k); return ArrayInfo { arr_type, - size_t(gguf_get_arr_n(ctx, k)), + gguf_get_arr_n(ctx, k), arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx, k), }; } @@ -445,7 +445,7 @@ namespace GGUFMeta { } if (n > N_MAX) { - throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str())); + throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", n, (uint32_t) N_MAX, key.c_str())); } if (gguf_get_kv_type(metadata, kid) == GGUF_TYPE_ARRAY) { @@ -502,9 +502,9 @@ namespace GGUFMeta { } // TODO: this is not very clever - figure out something better - template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr> (enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); - template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); llama_model_loader::llama_model_loader( diff --git a/src/llama-model-saver.cpp b/src/llama-model-saver.cpp index 539d17eebc6..26fda1abfae 100644 --- a/src/llama-model-saver.cpp +++ b/src/llama-model-saver.cpp @@ -14,9 +14,6 @@ bool llama_model_saver_supports_arch(llm_arch arch) { switch (arch) { - case LLM_ARCH_QWEN3NEXT: - case LLM_ARCH_QWEN35: - case LLM_ARCH_QWEN35MOE: case LLM_ARCH_PLAMO3: case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3N: @@ -107,6 +104,8 @@ void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, c gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT8, value.data(), n_values); } else if (std::is_same::value) { gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT32, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_BOOL, value.data(), n_values); } else if (std::is_same::value) { gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT32, value.data(), n_values); } else if (std::is_same::value) { @@ -245,7 +244,7 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); add_kv(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count); add_kv(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); - // add_kv(LLM_KV_FULL_ATTENTION_INTERVAL, ???); + // add_kv(LLM_KV_FULL_ATTENTION_INTERVAL, ???); // saved as LLM_KV_ATTENTION_RECURRENT_LAYERS instead add_kv(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, true); add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true); @@ -279,6 +278,7 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); add_kv(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); add_kv(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + add_kv(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, true); const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index bd5635ed456..3c2a8e78b78 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -373,10 +373,10 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str // count only the same type of previous layers to avoid this auto get_il_eff = [&](const size_t il){ size_t ret = 0; - const bool il_is_recurrent = hparams.is_recurrent(il); - const bool il_is_swa = hparams.is_swa(il); + const bool il_is_recr = hparams.is_recr(il); + const bool il_is_swa = hparams.is_swa(il); for (size_t il_prev = 0; il_prev < il; il_prev++) { - ret += hparams.is_recurrent(il_prev) == il_is_recurrent && hparams.is_swa(il_prev) == il_is_swa; + ret += hparams.is_recr(il_prev) == il_is_recr && hparams.is_swa(il_prev) == il_is_swa; } return ret; }; @@ -553,7 +553,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str }; auto get_split_granularity = [&](int64_t blck_size, uint32_t il, const std::vector> & segments) -> std::vector { - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // linear attention const int64_t head_dim = hparams.ssm_d_state; const int64_t granularity_qkv = std::lcm(blck_size, head_dim); @@ -1076,18 +1076,16 @@ void llama_model_base::load_hparams(llama_model_loader & ml) { std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); - std::fill( - hparams.recurrent_layer_arr.begin(), - hparams.recurrent_layer_arr.end(), - llm_arch_is_recurrent(ml.get_arch())); std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); - std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0); + std::fill(hparams.is_swa_impl.begin(), hparams.is_swa_impl.end(), 0); + std::fill(hparams.is_recr_impl.begin(), hparams.is_recr_impl.end(), llm_arch_is_recurrent(ml.get_arch()) ? 1 : 0); std::fill(hparams.xielu_alpha_n.begin(), hparams.xielu_alpha_n.end(), 0.0f); std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f); - std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); - std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); + std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + std::fill(hparams.swiglu_clamp_exp.begin(), hparams.swiglu_clamp_exp.end(), 0.0f); std::fill(hparams.swiglu_clamp_shexp.begin(), hparams.swiglu_clamp_shexp.end(), 0.0f); @@ -2040,18 +2038,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, filter_recr = [&](int32_t) { return true; }; } else if (arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { filter_attn = [&](int32_t il) { - return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + return !hparams.is_recr(il) && hparams.n_ff(il) == 0; }; filter_recr = [&](int32_t il) { - return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + return hparams.is_recr(il) && hparams.n_ff(il) == 0; }; } else if (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE) { const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; filter_attn = [&, n_main](int32_t il) { - return (uint32_t)il < n_main && !hparams.is_recurrent(il); + return (uint32_t)il < n_main && !hparams.is_recr(il); }; filter_recr = [&, n_main](int32_t il) { - return (uint32_t)il < n_main && hparams.is_recurrent(il); + return (uint32_t)il < n_main && hparams.is_recr(il); }; } diff --git a/src/models/falcon-h1.cpp b/src/models/falcon-h1.cpp index 94b65a3c7c9..c130ccdd49e 100644 --- a/src/models/falcon-h1.cpp +++ b/src/models/falcon-h1.cpp @@ -11,7 +11,7 @@ void llama_model_falcon_h1::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true); + std::fill(hparams.is_recr_impl.begin(), hparams.is_recr_impl.end(), true); switch (hparams.n_layer) { case 36: diff --git a/src/models/gemma4.cpp b/src/models/gemma4.cpp index 4f9d8b18bc7..de3f790c71c 100644 --- a/src/models/gemma4.cpp +++ b/src/models/gemma4.cpp @@ -2,7 +2,7 @@ void llama_model_gemma4::load_arch_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer); uint32_t n_kv_shared_layers = 0; ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); diff --git a/src/models/granite-hybrid.cpp b/src/models/granite-hybrid.cpp index 27f6706ea10..8740d9fc7d9 100644 --- a/src/models/granite-hybrid.cpp +++ b/src/models/granite-hybrid.cpp @@ -20,7 +20,7 @@ void llama_model_granite_hybrid::load_arch_hparams(llama_model_loader & ml) { // A layer is recurrent IFF the n_head_kv value is set to 0 for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; } ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -71,7 +71,7 @@ void llama_model_granite_hybrid::load_arch_tensors(llama_model_loader &) { // norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (hparams.is_recurrent(i)) { + if (hparams.is_recr(i)) { // ssm layers layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); @@ -158,7 +158,7 @@ llama_model_granite_hybrid::graph::graph(const llama_model & model, const llm_gr cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // ssm layer // cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il); } else { diff --git a/src/models/jamba.cpp b/src/models/jamba.cpp index 84ea63c3136..a62b121b3ee 100644 --- a/src/models/jamba.cpp +++ b/src/models/jamba.cpp @@ -9,7 +9,7 @@ void llama_model_jamba::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; } switch (hparams.n_layer) { diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp index ecffb105496..c13f71b5bcb 100644 --- a/src/models/kimi-linear.cpp +++ b/src/models/kimi-linear.cpp @@ -15,7 +15,7 @@ void llama_model_kimi_linear::load_arch_hparams(llama_model_loader & ml) { // Mark KDA layers as recurrent using n_head_kv pattern (like Jamba) // Set n_head_kv = 0 for KDA layers (recurrent), n_head_kv = n_head for MLA layers (attention) for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; // KDA layers are recurrent + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; // KDA layers are recurrent } // MoE parameters - Kimi uses moe_intermediate_size = 1024 @@ -53,7 +53,7 @@ void llama_model_kimi_linear::load_arch_tensors(llama_model_loader &) { const int64_t n_embd_head_v_kda = hparams.n_embd_head_kda; const int64_t ssm_d_conv = hparams.ssm_d_conv; - if (hparams.is_recurrent(i)) { + if (hparams.is_recr(i)) { // Conv1d weights: try 4D first, then 3D (quantization may remove trailing 1) // 4D: [d_conv, 1, d_inner, 1], 3D: [d_conv, 1, d_inner] layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); @@ -285,7 +285,7 @@ llama_model_kimi_linear::graph::graph(const llama_model & model, const llm_graph ggml_build_forward_expand(gf, cur); - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // === KDA Layer (Kimi Delta Attention) with Recurrent State === // Reference: vLLM kda.py const auto * mctx_cur = inp_rs->mctx; diff --git a/src/models/lfm2.cpp b/src/models/lfm2.cpp index 29081344b24..3898b56bb12 100644 --- a/src/models/lfm2.cpp +++ b/src/models/lfm2.cpp @@ -6,7 +6,7 @@ void llama_model_lfm2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + hparams.is_recr_impl[il] = hparams.n_head_kv(il) == 0; } hparams.n_layer_dense_lead = hparams.n_layer; switch (hparams.n_ff()) { @@ -19,7 +19,7 @@ void llama_model_lfm2::load_arch_hparams(llama_model_loader & ml) { if (const auto is_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); is_swa && hparams.n_swa > 0) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.swa_layers[il] = !hparams.recurrent_layer_arr[il]; + hparams.is_swa_impl[il] = !hparams.is_recr_impl[il]; } } } @@ -59,7 +59,7 @@ void llama_model_lfm2::load_arch_tensors(llama_model_loader &) { // for operator_norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recr(i)) { layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); @@ -235,8 +235,8 @@ llama_model_lfm2::graph::graph(const llama_model & model, const llm_graph_ cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "model.layers.{}.operator_norm", il); - cur = hparams.is_recurrent(il) ? build_shortconv_block(cur, inp_hybrid->get_recr(), il) : - build_attn_block(cur, inp_pos, inp_hybrid->get_attn(), il); + cur = hparams.is_recr(il) ? build_shortconv_block(cur, inp_hybrid->get_recr(), il) : + build_attn_block(cur, inp_pos, inp_hybrid->get_attn(), il); if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); diff --git a/src/models/lfm2moe.cpp b/src/models/lfm2moe.cpp index 12a66c05c7d..81ced2eaba2 100644 --- a/src/models/lfm2moe.cpp +++ b/src/models/lfm2moe.cpp @@ -10,7 +10,7 @@ void llama_model_lfm2moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + hparams.is_recr_impl[il] = hparams.n_head_kv(il) == 0; } switch (hparams.n_layer) { @@ -55,7 +55,7 @@ void llama_model_lfm2moe::load_arch_tensors(llama_model_loader &) { // for operator_norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recr(i)) { layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); diff --git a/src/models/llama4.cpp b/src/models/llama4.cpp index 0ff5376d571..8f39b3f59a5 100644 --- a/src/models/llama4.cpp +++ b/src/models/llama4.cpp @@ -15,7 +15,8 @@ void llama_model_llama4::load_arch_hparams(llama_model_loader & ml) { hparams.n_attn_temp_floor_scale = 8192; hparams.f_attn_temp_scale = 0.1f; hparams.f_attn_temp_offset = 1.0f; - uint32_t swa_period = 4; // pattern: 3 chunked - 1 full + + uint32_t swa_period = 4; // pattern: 3 chunked - 1 full ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); hparams.set_swa_pattern(swa_period); diff --git a/src/models/mellum.cpp b/src/models/mellum.cpp index a2372399bbd..1e1e97e9fa0 100644 --- a/src/models/mellum.cpp +++ b/src/models/mellum.cpp @@ -13,7 +13,7 @@ void llama_model_mellum::load_arch_hparams(llama_model_loader & ml) { if (res) { hparams.set_swa_pattern(swa_period); } else { - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer); } hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; diff --git a/src/models/mimo2.cpp b/src/models/mimo2.cpp index d0295ec116f..1bcdf696f2e 100644 --- a/src/models/mimo2.cpp +++ b/src/models/mimo2.cpp @@ -8,7 +8,8 @@ void llama_model_mimo2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer); float value_scale = 0.0f; if (ml.get_key(LLM_KV_ATTENTION_VALUE_SCALE, value_scale, false) && value_scale != 1.0f) { diff --git a/src/models/nemotron-h.cpp b/src/models/nemotron-h.cpp index a82f9c170b4..d2c811d2497 100644 --- a/src/models/nemotron-h.cpp +++ b/src/models/nemotron-h.cpp @@ -10,7 +10,7 @@ void llama_model_nemotron_h::load_arch_hparams(llama_model_loader & ml) { // A layer is recurrent IFF the n_head_kv value is set to 0 and // the n_ff value is set to 0 for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0); + hparams.is_recr_impl[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0); } ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -62,7 +62,7 @@ void llama_model_nemotron_h::load_arch_tensors(llama_model_loader &) { // all blocks use the attn norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (hparams.is_recurrent(i)) { + if (hparams.is_recr(i)) { // ssm layers layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); @@ -143,7 +143,7 @@ llama_model_nemotron_h::graph::graph(const llama_model & model, const llm_graph_ cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // ssm layer // cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il); } else if (hparams.n_ff(il) == 0) { diff --git a/src/models/plamo2.cpp b/src/models/plamo2.cpp index b713889fe72..2ffa0898f71 100644 --- a/src/models/plamo2.cpp +++ b/src/models/plamo2.cpp @@ -12,7 +12,7 @@ void llama_model_plamo2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; } switch (hparams.n_layer) { @@ -54,7 +54,7 @@ void llama_model_plamo2::load_arch_tensors(llama_model_loader &) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - bool is_mamba_layer = hparams.is_recurrent(i); + bool is_mamba_layer = hparams.is_recr(i); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -128,7 +128,7 @@ llama_model_plamo2::graph::graph(const llama_model & model, const llm_graph_para cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); // check if this layer is Mamba or Attention - const bool is_mamba_layer = hparams.is_recurrent(il); + const bool is_mamba_layer = hparams.is_recr(il); if (is_mamba_layer) { // PLaMo-2 Mamba layer diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index ba63ae441df..f8fd5369623 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -18,12 +18,13 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { // Mark recurrent layers (linear attention layers). MTP layers are dense // attention-only and must be flagged non-recurrent. - { + if (!ml.get_key_or_arr(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, hparams.n_layer, false)) { const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); + hparams.is_recr_impl[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); } } @@ -69,7 +70,7 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(il)) { + if (!hparams.is_recr(il)) { // Attention layers create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); @@ -168,7 +169,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para ggml_build_forward_expand(gf, cur); // Determine layer type and build appropriate attention mechanism - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // Linear attention layer (gated delta net) cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 4f87d55d911..8db0b4717d9 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -21,12 +21,13 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { // Mark recurrent layers (linear attention layers). MTP layers are dense // attention-only and must be flagged non-recurrent. - { + if (!ml.get_key_or_arr(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, hparams.n_layer, false)) { const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); + hparams.is_recr_impl[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); } } @@ -75,7 +76,7 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(il)) { + if (!hparams.is_recr(il)) { // Attention layers create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); @@ -191,7 +192,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p ggml_build_forward_expand(gf, cur); // Determine layer type and build appropriate attention mechanism - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // Linear attention layer (gated delta net) cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index 1d873427db5..9e09ae6f232 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -14,11 +14,11 @@ void llama_model_qwen3next::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); // Mark recurrent layers (linear attention layers) - { + if (!ml.get_key_or_arr(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, hparams.n_layer, false)) { uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + hparams.is_recr_impl[i] = ((i + 1) % full_attn_interval != 0); } } @@ -68,7 +68,7 @@ void llama_model_qwen3next::load_arch_tensors(llama_model_loader &) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recr(i)) { // Attention layers create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); @@ -129,7 +129,7 @@ llama_model_qwen3next::graph::graph(const llama_model & model, const llm_graph_p ggml_build_forward_expand(gf, cur); // Determine layer type and build appropriate attention mechanism - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // Linear attention layer (gated delta net) cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { diff --git a/src/models/step35.cpp b/src/models/step35.cpp index caf18c743ff..49e7b3b2a03 100644 --- a/src/models/step35.cpp +++ b/src/models/step35.cpp @@ -22,7 +22,9 @@ void llama_model_step35::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer, false);