From cc0d6c28d6627d855dcec36878981d9d63e9c5d3 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 9 Feb 2026 15:38:55 +0100 Subject: [PATCH 01/10] model: support GLM MoE DSA arch --- convert_hf_to_gguf.py | 13 ++++ gguf-py/gguf/constants.py | 33 ++++++++++ src/llama-arch.cpp | 1 + src/llama-arch.h | 1 + src/llama-model.cpp | 133 +++++++++++++++++++++++++++++++++++++- 5 files changed, 180 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 843c00a8969..6927a57cd51 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -8698,6 +8698,19 @@ def set_vocab(self): special_vocab.add_to_gguf(self.gguf_writer) +@ModelBase.register("GlmMoeDsaForCausalLM") +class GlmMoeDsaModel(DeepseekV2Model, Glm4MoeModel): + model_arch = gguf.MODEL_ARCH.GLM_DSA + + def set_gguf_parameters(self): + # combine DeepseekV2Model + GLM4MoeModel parameters + super().set_gguf_parameters() + + def modify_tensors(self, data_torch, name, bid): + # note: skip Glm4MoeModel super method + return super(DeepseekV2Model).modify_tensors(data_torch, name, bid) + + @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration") class ChatGLMModel(TextModel): model_arch = gguf.MODEL_ARCH.CHATGLM diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 3af4fffe957..7aea1678d62 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -422,6 +422,7 @@ class MODEL_ARCH(IntEnum): CHATGLM = auto() GLM4 = auto() GLM4_MOE = auto() + GLM_DSA = auto() BITNET = auto() T5 = auto() T5ENCODER = auto() @@ -852,6 +853,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.CHATGLM: "chatglm", MODEL_ARCH.GLM4: "glm4", MODEL_ARCH.GLM4_MOE: "glm4moe", + MODEL_ARCH.GLM_DSA: "glm-dsa", MODEL_ARCH.BITNET: "bitnet", MODEL_ARCH.T5: "t5", MODEL_ARCH.T5ENCODER: "t5encoder", @@ -2615,6 +2617,37 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], + MODEL_ARCH.GLM_DSA: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + # NextN/MTP tensors - preserved but unused + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, + ], MODEL_ARCH.BITNET: [ MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index bd78f1e5562..3634a156927 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -72,6 +72,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_GLM4_MOE, "glm4moe" }, + { LLM_ARCH_GLM_DSA, "glm-dsa" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index e8263369b80..5997de9960b 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -76,6 +76,7 @@ enum llm_arch { LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, LLM_ARCH_GLM4_MOE, + 
LLM_ARCH_GLM_DSA, LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 674d06c8910..5f6e4a49eb1 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1820,6 +1820,44 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GLM_DSA: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // deepseek MLA parameters + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + + // Expert gating function (GLM-4.5 uses sigmoid) + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + switch (hparams.n_layer) { + // TODO + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5430,6 +5468,97 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; + case LLM_ARCH_GLM_DSA: + { + const bool is_mla = hparams.is_mla(); + if (!is_mla) { + throw std::runtime_error("GLM_DSA architecture requires MLA"); + } + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = 
layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); + + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = 
create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } + } break; case LLM_ARCH_NEMOTRON: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -7576,7 +7705,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); } - if (arch == LLM_ARCH_DEEPSEEK2) { + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA) { LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); @@ -8149,6 +8278,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_GLM_DSA: { llm = std::make_unique(*this, params); } break; @@ -8542,6 +8672,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MISTRAL3: case LLM_ARCH_LLAMA_EMBED: case LLM_ARCH_MAINCODER: + case LLM_ARCH_GLM_DSA: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 From a44a3dbc4b7e869aca11ce243eaa31da90f8144e Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Mon, 9 Feb 2026 19:03:46 +0100 Subject: [PATCH 02/10] working version --- convert_hf_to_gguf.py | 67 ++++++++++++++++++++++----------------- gguf-py/gguf/constants.py | 18 +++++++---- src/llama-arch.cpp | 36 +++++++++++++++++++++ src/llama-model.cpp | 49 +++++++++++++++------------- src/models/deepseek2.cpp | 5 +-- 5 files changed, 116 insertions(+), 59 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 6927a57cd51..956be8049db 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7585,6 +7585,9 @@ def prepare_tensors(self): class DeepseekV2Model(TextModel): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 + # TODO @ngxson : remove this when we support MTP for deepseek models + skip_mtp = True + def set_vocab(self): try: self._set_vocab_gpt2() @@ -7716,10 +7719,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter name = name.replace("e_score_correction_bias", "e_score_correction.bias") # skip Multi-Token Prediction (MTP) layers - block_count = self.hparams["num_hidden_layers"] - match = re.match(r"model.layers.(\d+)", name) - if match and int(match.group(1)) >= block_count: - return + if self.skip_mtp: + block_count = self.hparams["num_hidden_layers"] + match = re.match(r"model.layers.(\d+)", name) + if match and int(match.group(1)) >= block_count: + return # process the experts separately if name.find("mlp.experts") != -1: @@ -8558,7 +8562,9 @@ def __init__(self, *args, **kwargs): self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) - def set_vocab(self): + # using staticmethod here to allow re-using it in other classes + @staticmethod + def set_vocab_glm(self: TextModel): from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(self.dir_model) @@ -8568,7 +8574,6 @@ def set_vocab(self): self.gguf_writer.add_tokenizer_pre(tokpre) self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) - # Special tokens # Note: Using <|endoftext|> (151329) for eot causes endless generation special_vocab._set_special_token("bos", 
tokenizer.get_added_vocab()["[gMASK]"]) # 151331 @@ -8578,6 +8583,9 @@ def set_vocab(self): special_vocab.add_to_gguf(self.gguf_writer) + def set_vocab(self): + Glm4MoeModel.set_vocab_glm(self) + def set_gguf_parameters(self): super().set_gguf_parameters() if (rope_dim := self.hparams.get("head_dim")) is None: @@ -8676,39 +8684,40 @@ def prepare_tensors(self): class Glm4MoeLiteModel(DeepseekV2Model): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 - # copied from Glm4MoeModel def set_vocab(self): - from transformers import AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained(self.dir_model) - special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) - tokens, toktypes, tokpre = self.get_vocab_base() - self.gguf_writer.add_tokenizer_model("gpt2") - self.gguf_writer.add_tokenizer_pre(tokpre) - self.gguf_writer.add_token_list(tokens) - self.gguf_writer.add_token_types(toktypes) - - # Special tokens - # Note: Using <|endoftext|> (151329) for eot causes endless generation - special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331 - special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329 - special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 - - special_vocab.add_to_gguf(self.gguf_writer) + Glm4MoeModel.set_vocab_glm(self) @ModelBase.register("GlmMoeDsaForCausalLM") -class GlmMoeDsaModel(DeepseekV2Model, Glm4MoeModel): +class GlmMoeDsaModel(DeepseekV2Model): model_arch = gguf.MODEL_ARCH.GLM_DSA + skip_mtp = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # GLM4_MOE has num_hidden_layers + 1 actual layers (including NextN layer) + self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0) + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def set_vocab(self): + Glm4MoeModel.set_vocab_glm(self) def set_gguf_parameters(self): - # combine DeepseekV2Model + GLM4MoeModel parameters super().set_gguf_parameters() + rope_dim = self.hparams["qk_rope_head_dim"] + partial_rotary_factor = self.hparams["partial_rotary_factor"] + self.gguf_writer.add_rope_dimension_count(int(rope_dim * partial_rotary_factor)) + + # Expert gating function (sigmoid for GLM4_MOE) + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + + # NextN/MTP prediction layers + if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None: + self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers) + def modify_tensors(self, data_torch, name, bid): - # note: skip Glm4MoeModel super method - return super(DeepseekV2Model).modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration") diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 7aea1678d62..4a6a9366556 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -2621,18 +2621,24 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, - MODEL_TENSOR.ATTN_POST_NORM, MODEL_TENSOR.ATTN_Q, - MODEL_TENSOR.ATTN_K, - MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_Q_A, + MODEL_TENSOR.ATTN_Q_B, + MODEL_TENSOR.ATTN_KV_A_MQA, + MODEL_TENSOR.ATTN_KV_B, + 
MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V_B, + MODEL_TENSOR.ATTN_Q_A_NORM, + MODEL_TENSOR.ATTN_KV_A_NORM, MODEL_TENSOR.ATTN_OUT, - MODEL_TENSOR.ATTN_Q_NORM, - MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, - MODEL_TENSOR.FFN_GATE_INP, MODEL_TENSOR.FFN_GATE_EXP, MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 3634a156927..1398a31db49 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1598,6 +1598,42 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, }; + case LLM_ARCH_GLM_DSA: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + }; case LLM_ARCH_BITNET: return { LLM_TENSOR_TOKEN_EMBD, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 5f6e4a49eb1..5188bec97e6 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1835,6 +1835,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); // deepseek MLA parameters + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); @@ -5499,32 +5500,37 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + auto & layer = layers[i]; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); - layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, flags); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, flags); - layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); - + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, 
q_lora_rank}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags); - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0); + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, flags); // note: only old legacy GGUF files will have the unsplit wkv_b tensor in - layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0); - layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, flags); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, flags); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, flags); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); if (i < (int) hparams.n_layer_dense_lead) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); } else { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); if (n_expert == 0) { @@ -5535,18 +5541,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // MoE branch - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); // Shared expert branch - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", 
i), {n_embd, n_ff_exp * n_expert_shared}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); } // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - int flags = 0; if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index 987f449934c..b2c1f160601 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -45,7 +45,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + int effective_n_layers = hparams.n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < effective_n_layers; ++il) { ggml_tensor * inpSA = inpL; // norm @@ -188,7 +189,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } } - if (il == n_layer - 1 && inp_out_ids) { + if (il == effective_n_layers - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } From 0451c849eef532d604994316d3398febf871fcd2 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Mon, 9 Feb 2026 19:09:00 +0100 Subject: [PATCH 03/10] pyright --- convert_hf_to_gguf.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 956be8049db..fc95e7ae197 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -8564,16 +8564,16 @@ def __init__(self, *args, **kwargs): # using staticmethod here to allow re-using it in other classes @staticmethod - def set_vocab_glm(self: TextModel): + def set_vocab_glm(model: TextModel): from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(self.dir_model) - special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) - tokens, toktypes, tokpre = self.get_vocab_base() - self.gguf_writer.add_tokenizer_model("gpt2") - self.gguf_writer.add_tokenizer_pre(tokpre) - self.gguf_writer.add_token_list(tokens) - self.gguf_writer.add_token_types(toktypes) + tokenizer = AutoTokenizer.from_pretrained(model.dir_model) + special_vocab = gguf.SpecialVocab(model.dir_model, load_merges=True) + tokens, toktypes, tokpre = model.get_vocab_base() + model.gguf_writer.add_tokenizer_model("gpt2") + model.gguf_writer.add_tokenizer_pre(tokpre) + model.gguf_writer.add_token_list(tokens) + model.gguf_writer.add_token_types(toktypes) # Special tokens # Note: Using <|endoftext|> (151329) for eot causes endless generation special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331 @@ -8581,7 +8581,7 @@ def set_vocab_glm(self: TextModel): special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329 special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 - special_vocab.add_to_gguf(self.gguf_writer) + special_vocab.add_to_gguf(model.gguf_writer) def set_vocab(self): 
Glm4MoeModel.set_vocab_glm(self) From 9e4e556cc03256bd859d25af1c2502abf0b99c6c Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Thu, 12 Feb 2026 00:52:52 +0100 Subject: [PATCH 04/10] keep indexer tensors --- convert_hf_to_gguf.py | 2 +- gguf-py/gguf/constants.py | 12 ++++++++++++ gguf-py/gguf/tensor_mapping.py | 16 ++++++++++++++++ src/llama-arch.cpp | 8 ++++++++ src/llama-arch.h | 4 ++++ src/llama-model.cpp | 7 +++++++ src/llama-model.h | 7 +++++++ 7 files changed, 55 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index fc95e7ae197..67cfcf9a376 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -8706,7 +8706,7 @@ def set_gguf_parameters(self): super().set_gguf_parameters() rope_dim = self.hparams["qk_rope_head_dim"] - partial_rotary_factor = self.hparams["partial_rotary_factor"] + partial_rotary_factor = self.hparams.get("partial_rotary_factor", 1.0) self.gguf_writer.add_rope_dimension_count(int(rope_dim * partial_rotary_factor)) # Expert gating function (sigmoid for GLM4_MOE) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 4a6a9366556..05e131ac30c 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -667,6 +667,10 @@ class MODEL_TENSOR(IntEnum): VISEXP_GATE = auto() VISEXP_DOWN = auto() VISEXP_UP = auto() + INDEXER_K_NORM = auto() + INDEXER_PROJ = auto() + INDEXER_ATTN_K = auto() + INDEXER_ATTN_Q_B = auto() # vision V_MMPROJ = auto() V_MMPROJ_FC = auto() @@ -1096,6 +1100,10 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.VISEXP_GATE: "blk.{bid}.vis_gate", MODEL_TENSOR.VISEXP_DOWN: "blk.{bid}.vis_down", MODEL_TENSOR.VISEXP_UP: "blk.{bid}.vis_up", + MODEL_TENSOR.INDEXER_K_NORM: "blk.{bid}.indexer.k_norm", + MODEL_TENSOR.INDEXER_PROJ: "blk.{bid}.indexer.proj", + MODEL_TENSOR.INDEXER_ATTN_K: "blk.{bid}.indexer.attn_k", + MODEL_TENSOR.INDEXER_ATTN_Q_B: "blk.{bid}.indexer.attn_q_b", # vision MODEL_TENSOR.V_MMPROJ: "mm.{bid}", MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc", @@ -2646,6 +2654,10 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, MODEL_TENSOR.FFN_EXP_PROBS_B, + MODEL_TENSOR.INDEXER_K_NORM, + MODEL_TENSOR.INDEXER_PROJ, + MODEL_TENSOR.INDEXER_ATTN_K, + MODEL_TENSOR.INDEXER_ATTN_Q_B, # NextN/MTP tensors - preserved but unused MODEL_TENSOR.NEXTN_EH_PROJ, MODEL_TENSOR.NEXTN_EMBED_TOKENS, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 167ade78033..0c944d77a04 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1199,6 +1199,22 @@ class TensorNameMap: "model.layers.{bid}.self_attn.vision_expert_query_key_value", # cogvlm ), + MODEL_TENSOR.INDEXER_K_NORM: ( + "model.layers.{bid}.self_attn.indexer.k_norm", # DSA + ), + + MODEL_TENSOR.INDEXER_PROJ: ( + "model.layers.{bid}.self_attn.indexer.weights_proj", # DSA + ), + + MODEL_TENSOR.INDEXER_ATTN_K: ( + "model.layers.{bid}.self_attn.indexer.wk", # DSA + ), + + MODEL_TENSOR.INDEXER_ATTN_Q_B: ( + "model.layers.{bid}.self_attn.indexer.wq_b", # DSA + ), + ############################################################################ # TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg MODEL_TENSOR.ENC_OUTPUT_NORM: ( diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 1398a31db49..61f444a1680 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -513,6 +513,10 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_VISEXP_FFN_GATE, "blk.%d.vis_gate" }, { LLM_TENSOR_VISEXP_FFN_DOWN, "blk.%d.vis_down" }, { 
LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" }, + { LLM_TENSOR_INDEXER_K_NORM, "blk.%d.indexer.k_norm" }, + { LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" }, + { LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" }, + { LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" }, }; static std::set llm_get_tensor_names(llm_arch arch) { @@ -1627,6 +1631,10 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_DOWN_SHEXP, LLM_TENSOR_FFN_UP_SHEXP, LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_INDEXER_K_NORM, + LLM_TENSOR_INDEXER_PROJ, + LLM_TENSOR_INDEXER_ATTN_K, + LLM_TENSOR_INDEXER_ATTN_Q_B, LLM_TENSOR_NEXTN_EH_PROJ, LLM_TENSOR_NEXTN_EMBED_TOKENS, LLM_TENSOR_NEXTN_ENORM, diff --git a/src/llama-arch.h b/src/llama-arch.h index 5997de9960b..da9153455bd 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -514,6 +514,10 @@ enum llm_tensor { LLM_TENSOR_VISEXP_FFN_GATE, LLM_TENSOR_VISEXP_FFN_DOWN, LLM_TENSOR_VISEXP_FFN_UP, + LLM_TENSOR_INDEXER_K_NORM, + LLM_TENSOR_INDEXER_PROJ, + LLM_TENSOR_INDEXER_ATTN_K, + LLM_TENSOR_INDEXER_ATTN_Q_B, LLM_TENSOR_NEXTN_EH_PROJ, LLM_TENSOR_NEXTN_EMBED_TOKENS, LLM_TENSOR_NEXTN_ENORM, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 5188bec97e6..163fc234b73 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -5525,6 +5525,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + // DSA indexer + layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {n_embd_head_k_mla}, flags); + layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {n_embd_head_k_mla}, flags); + layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, n_head}, flags); + layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, n_embd_head_k_mla}, flags); + layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags); + if (i < (int) hparams.n_layer_dense_lead) { layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); diff --git a/src/llama-model.h b/src/llama-model.h index 7b580043b33..3af30c02d31 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -425,6 +425,13 @@ struct llama_layer { struct ggml_tensor * ssm_g_b = nullptr; struct ggml_tensor * ssm_o_norm = nullptr; + // DSA (deepseek sparse attention) + struct ggml_tensor * indexer_k_norm = nullptr; + struct ggml_tensor * indexer_k_norm_b = nullptr; + struct ggml_tensor * indexer_proj = nullptr; + struct ggml_tensor * indexer_attn_k = nullptr; + struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; From 64184c12363507be56cc0b1922eecc60015d19b4 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Thu, 12 Feb 2026 01:04:28 +0100 Subject: [PATCH 05/10] add indexer gguf params --- convert_hf_to_gguf.py | 5 +++++ gguf-py/gguf/constants.py | 5 +++++ gguf-py/gguf/gguf_writer.py | 9 +++++++++ src/llama-arch.cpp | 4 ++++ src/llama-model.cpp | 6 +++--- 5 files changed, 26 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 67cfcf9a376..12cb35f9b6c 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -8716,6 +8716,11 @@ def set_gguf_parameters(self): if 
(num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None: self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers) + # DSA indexer parameters + self.gguf_writer.add_indexer_head_count(self.hparams["index_n_heads"]) + self.gguf_writer.add_indexer_key_length(self.hparams["index_head_dim"]) + self.gguf_writer.add_indexer_top_k(self.hparams["index_topk"]) + def modify_tensors(self, data_torch, name, bid): yield from super().modify_tensors(data_torch, name, bid) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 05e131ac30c..09b7bbd6f61 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -180,6 +180,11 @@ class Attention: SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern" TEMPERATURE_SCALE = "{arch}.attention.temperature_scale" + class Indexer: + HEAD_COUNT = "{arch}.attention.indexer.head_count" + KEY_LENGTH = "{arch}.attention.indexer.key_length" + TOP_K = "{arch}.attention.indexer.top_k" + class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" DIMENSION_SECTIONS = "{arch}.rope.dimension_sections" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 62172b24c38..1f0ab6fafc5 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -768,6 +768,15 @@ def add_key_length_mla(self, length: int) -> None: def add_value_length_mla(self, length: int) -> None: self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length) + def add_indexer_head_count(self, count: int | Sequence[int]) -> None: + self.add_uint32(Keys.Attention.Indexer.HEAD_COUNT.format(arch=self.arch), count) + + def add_indexer_key_length(self, length: int) -> None: + self.add_uint32(Keys.Attention.Indexer.KEY_LENGTH.format(arch=self.arch), length) + + def add_indexer_top_k(self, top_k: int) -> None: + self.add_uint32(Keys.Attention.Indexer.TOP_K.format(arch=self.arch), top_k) + def add_max_alibi_bias(self, bias: float) -> None: self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 61f444a1680..29095bbab76 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -2627,6 +2627,10 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_VISEXP_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_VISEXP_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_VISEXP_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are currently ignored (reserved for future MTP support) // These tensors only exist in the last layer(s) and are treated as output tensors {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 163fc234b73..629d2bae6ab 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -5526,10 +5526,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); // DSA indexer - layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {n_embd_head_k_mla}, flags); - layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", 
i), {n_embd_head_k_mla}, flags); + layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {n_embd_head_k}, flags); + layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {n_embd_head_k}, flags); layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, n_head}, flags); - layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, n_embd_head_k_mla}, flags); + layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, n_embd_head_k}, flags); layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags); if (i < (int) hparams.n_layer_dense_lead) { From d8a465650c0be31ac51374543e9c8697baf30eba Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Thu, 12 Feb 2026 01:12:42 +0100 Subject: [PATCH 06/10] loaded now --- convert_hf_to_gguf.py | 2 ++ src/llama-arch.cpp | 3 +++ src/llama-arch.h | 3 +++ src/llama-hparams.h | 5 +++++ src/llama-model.cpp | 19 ++++++++++++------- 5 files changed, 25 insertions(+), 7 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 12cb35f9b6c..5f264f4af06 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -8705,6 +8705,8 @@ def set_vocab(self): def set_gguf_parameters(self): super().set_gguf_parameters() + self.gguf_writer.add_leading_dense_block_count(3) # TODO: not to hard-code this for future models + rope_dim = self.hparams["qk_rope_head_dim"] partial_rotary_factor = self.hparams.get("partial_rotary_factor", 1.0) self.gguf_writer.add_rope_dimension_count(int(rope_dim * partial_rotary_factor)) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 29095bbab76..53e42499085 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -223,6 +223,9 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, + { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" }, + { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" }, + { LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index da9153455bd..1fc1a530068 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -227,6 +227,9 @@ enum llm_kv { LLM_KV_ATTENTION_TEMPERATURE_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, + LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, + LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, + LLM_KV_ATTENTION_INDEXER_TOP_K, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 6c695bdbf66..fc260724710 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -194,6 +194,11 @@ struct llama_hparams { std::array xielu_beta; std::array xielu_eps; + // DSA (deepseek sparse attention) + uint32_t indexer_n_head = 0; + uint32_t indexer_head_size = 0; + uint32_t indexer_top_k = 0; + // qwen3vl deepstack uint32_t n_deepstack_layers = 0; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 629d2bae6ab..d12894d9b30 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1842,6 +1842,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { 
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + // DSA parameters + ml.get_key(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); + ml.get_key(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); + ml.get_key(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + // Expert gating function (GLM-4.5 uses sigmoid) ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { @@ -5503,7 +5508,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { int flags = 0; if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { // skip all tensors in the NextN layers - flags |= TENSOR_SKIP; + // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later + flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED; } auto & layer = layers[i]; @@ -5526,12 +5532,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); // DSA indexer - layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {n_embd_head_k}, flags); - layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {n_embd_head_k}, flags); - layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, n_head}, flags); - layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, n_embd_head_k}, flags); - layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags); - + layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags); + layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags); + layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags); + layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags); + layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags); if (i < (int) hparams.n_layer_dense_lead) { layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); From 9bfafada7f2cf3f10c44b373fe327091d97a5e63 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Thu, 12 Feb 2026 12:41:22 +0100 Subject: [PATCH 07/10] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- gguf-py/gguf/gguf_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 1f0ab6fafc5..e353d632b60 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -768,7 +768,7 @@ def add_key_length_mla(self, length: int) -> None: def add_value_length_mla(self, length: int) -> None: self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length) - def add_indexer_head_count(self, count: int | Sequence[int]) -> None: + def add_indexer_head_count(self, count: int) -> None: 
self.add_uint32(Keys.Attention.Indexer.HEAD_COUNT.format(arch=self.arch), count) def add_indexer_key_length(self, length: int) -> None: From 825540135ff3181aa3fed5de163517f34e642ba4 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 12 Feb 2026 16:24:55 +0100 Subject: [PATCH 08/10] update --- convert_hf_to_gguf.py | 50 ++++++++++++++++--------------------------- src/llama-model.cpp | 3 ++- src/llama-model.h | 1 + 3 files changed, 22 insertions(+), 32 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 5f264f4af06..40794d822eb 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1605,6 +1605,23 @@ def _set_vocab_glmedge(self): special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) special_vocab.add_to_gguf(self.gguf_writer) + def _set_vocab_glm(self): + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + # Special tokens + # Note: Using <|endoftext|> (151329) for eot causes endless generation + special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331 + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336 + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329 + special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 + special_vocab.add_to_gguf(self.gguf_writer) + def _set_vocab_interns1(self): tokens: list[str] = [] toktypes: list[int] = [] @@ -8562,27 +8579,6 @@ def __init__(self, *args, **kwargs): self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) - # using staticmethod here to allow re-using it in other classes - @staticmethod - def set_vocab_glm(model: TextModel): - from transformers import AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained(model.dir_model) - special_vocab = gguf.SpecialVocab(model.dir_model, load_merges=True) - tokens, toktypes, tokpre = model.get_vocab_base() - model.gguf_writer.add_tokenizer_model("gpt2") - model.gguf_writer.add_tokenizer_pre(tokpre) - model.gguf_writer.add_token_list(tokens) - model.gguf_writer.add_token_types(toktypes) - # Special tokens - # Note: Using <|endoftext|> (151329) for eot causes endless generation - special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331 - special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329 - special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 - - special_vocab.add_to_gguf(model.gguf_writer) - def set_vocab(self): Glm4MoeModel.set_vocab_glm(self) @@ -8685,7 +8681,7 @@ class Glm4MoeLiteModel(DeepseekV2Model): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 def set_vocab(self): - Glm4MoeModel.set_vocab_glm(self) + return self._set_vocab_glm() @ModelBase.register("GlmMoeDsaForCausalLM") @@ -8700,20 +8696,15 @@ def __init__(self, *args, **kwargs): self.tensor_map = gguf.get_tensor_name_map(self.model_arch, 
self.block_count) def set_vocab(self): - Glm4MoeModel.set_vocab_glm(self) + return self._set_vocab_glm() def set_gguf_parameters(self): super().set_gguf_parameters() - self.gguf_writer.add_leading_dense_block_count(3) # TODO: not to hard-code this for future models - rope_dim = self.hparams["qk_rope_head_dim"] partial_rotary_factor = self.hparams.get("partial_rotary_factor", 1.0) self.gguf_writer.add_rope_dimension_count(int(rope_dim * partial_rotary_factor)) - # Expert gating function (sigmoid for GLM4_MOE) - self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) - # NextN/MTP prediction layers if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None: self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers) @@ -8723,9 +8714,6 @@ def set_gguf_parameters(self): self.gguf_writer.add_indexer_key_length(self.hparams["index_head_dim"]) self.gguf_writer.add_indexer_top_k(self.hparams["index_topk"]) - def modify_tensors(self, data_torch, name, bid): - yield from super().modify_tensors(data_torch, name, bid) - @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration") class ChatGLMModel(TextModel): diff --git a/src/llama-model.cpp b/src/llama-model.cpp index d12894d9b30..a06fc65096e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -136,6 +136,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_300B_A47B: return "300B.A47B"; case LLM_TYPE_310B_A15B: return "310B.A15B"; case LLM_TYPE_355B_A32B: return "355B.A32B"; + case LLM_TYPE_744B_A40B: return "744B.A40B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; default: return "?B"; @@ -1860,7 +1861,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; switch (hparams.n_layer) { - // TODO + case 78: type = LLM_TYPE_744B_A40B; break; default: type = LLM_TYPE_UNKNOWN; } } break; diff --git a/src/llama-model.h b/src/llama-model.h index 3af30c02d31..51c42a51a56 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -129,6 +129,7 @@ enum llm_type { LLM_TYPE_300B_A47B, // Ernie MoE big LLM_TYPE_310B_A15B, // /MiMo-V2-Flash LLM_TYPE_355B_A32B, // GLM-4.5 + LLM_TYPE_744B_A40B, // GLM-5 LLM_TYPE_E2B, LLM_TYPE_E4B, }; From 7b23cd920799f7e06a90ff99c8c792205723ec35 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Thu, 12 Feb 2026 19:08:20 +0100 Subject: [PATCH 09/10] Update src/llama-model.cpp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- src/llama-model.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a06fc65096e..b873d6c1de3 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1861,7 +1861,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; switch (hparams.n_layer) { - case 78: type = LLM_TYPE_744B_A40B; break; + case 79: type = LLM_TYPE_744B_A40B; break; default: type = LLM_TYPE_UNKNOWN; } } break; From 1daef5f85fe801257d84edffeabe52a0e1da216c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Fri, 13 Feb 2026 11:57:45 +0100 Subject: [PATCH 10/10] minor fix and cleanup --- convert_hf_to_gguf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 40794d822eb..08f84a86a82 100755 --- 
a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -8580,7 +8580,7 @@ def __init__(self, *args, **kwargs): self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) def set_vocab(self): - Glm4MoeModel.set_vocab_glm(self) + return self._set_vocab_glm() def set_gguf_parameters(self): super().set_gguf_parameters() @@ -8691,7 +8691,6 @@ class GlmMoeDsaModel(DeepseekV2Model): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # GLM4_MOE has num_hidden_layers + 1 actual layers (including NextN layer) self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
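Reviewer note on the conversion-class design settled on above. PATCH 01 tried a diamond `class GlmMoeDsaModel(DeepseekV2Model, Glm4MoeModel)` with `super(DeepseekV2Model).modify_tensors(...)`; a one-argument `super()` is unbound, so that call raises `AttributeError` at conversion time, and the later patches drop the diamond in favour of subclassing `DeepseekV2Model` alone and toggling behaviour with the `skip_mtp` class attribute. The sketch below illustrates both points with stand-in `Mini*` classes and an illustrative layer count; it is not the real converter code.

```python
# Illustrative sketch only: Mini* names are stand-ins, not the real classes
# from convert_hf_to_gguf.py.
import re

class MiniTextModel:
    def modify_tensors(self, data, name, bid):
        return [(name, data)]

class MiniDeepseekV2(MiniTextModel):
    skip_mtp = True  # mirrors the skip_mtp flag introduced in PATCH 02

    def modify_tensors(self, data, name, bid):
        if self.skip_mtp:
            # mirrors the MTP-layer skip in DeepseekV2Model.modify_tensors
            m = re.match(r"model\.layers\.(\d+)", name)
            if m and int(m.group(1)) >= 61:  # 61 = num_hidden_layers, illustrative
                return []
        return [(f"ds2:{name}", data)]

class MiniGlm4Moe(MiniTextModel):
    def modify_tensors(self, data, name, bid):
        return [(f"glm4moe:{name}", data)]

# PATCH 01 shape: diamond inheritance plus a one-argument (unbound) super()
class MiniGlmMoeDsaBroken(MiniDeepseekV2, MiniGlm4Moe):
    def modify_tensors(self, data, name, bid):
        # raises AttributeError: an unbound super object does not expose the
        # methods of the next class in the MRO
        return super(MiniDeepseekV2).modify_tensors(data, name, bid)

# PATCH 02+ shape: inherit only the Deepseek converter and steer behaviour
# with class attributes instead of the MRO
class MiniGlmMoeDsa(MiniDeepseekV2):
    skip_mtp = False  # keep NextN/MTP layer tensors, as the GGUF arch preserves them

print(MiniGlmMoeDsa().modify_tensors(None, "model.layers.78.mlp.gate", 78))
# -> [('ds2:model.layers.78.mlp.gate', None)]
```

The same single-inheritance idea carries the vocab handling: the shared GLM tokenizer setup ends up as a `TextModel._set_vocab_glm()` helper that both `Glm4MoeLiteModel` and `GlmMoeDsaModel` call, rather than reaching across sibling classes.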