From c75503861226fd18f14f31a9b61d9d81f2e7fcf3 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 29 Jul 2025 17:52:59 +1000 Subject: [PATCH 01/48] feat: support GLM 4.5 family of models --- convert_hf_to_gguf.py | 209 ++++++++++++++++++++++++++++ gguf-py/gguf/constants.py | 51 +++++++ src/llama-arch.cpp | 42 ++++++ src/llama-arch.h | 7 + src/llama-model.cpp | 285 ++++++++++++++++++++++++++++++++++++++ src/llama-model.h | 2 + 6 files changed, 596 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3f5cefe007cca..12f22df249e57 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6578,6 +6578,215 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return super().modify_tensors(data_torch, name, bid) +@ModelBase.register("Glm4MoeForCausalLM") +class Glm4MoeModel(TextModel): + model_arch = gguf.MODEL_ARCH.GLM4_MOE + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # GLM4_MOE has num_hidden_layers + 1 actual layers (including NextN layer) + self.block_count = self.hparams["num_hidden_layers"] + 1 + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def set_vocab(self): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained( + self.dir_model, trust_remote_code=True + ) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + special_vocab._set_special_token( + "eos", tokenizer.get_added_vocab()["<|endoftext|>"] + ) + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) + special_vocab._set_special_token("eog", tokenizer.get_added_vocab()["<|user|>"]) + special_vocab._set_special_token("eog", tokenizer.get_added_vocab()["<|observation|>"]) + special_vocab._set_special_token( + "unk", tokenizer.get_added_vocab()["<|endoftext|>"] + ) + special_vocab._set_special_token( + "bos", tokenizer.get_added_vocab()["<|endoftext|>"] + ) + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + if (rope_dim := self.hparams.get("head_dim")) is None: + rope_dim = ( + self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + ) + self.gguf_writer.add_rope_dimension_count( + int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)) + ) + + # MoE parameters + if (n_experts := self.hparams.get("n_routed_experts")) is not None: + self.gguf_writer.add_expert_count(n_experts) + # Note: expert_used_count is already set by parent class using num_experts_per_tok + if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) + if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None: + self.gguf_writer.add_expert_shared_count(n_shared_experts) + if (first_k_dense_replace := self.hparams.get("first_k_dense_replace")) is not None: + self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace) + + # Expert gating function (sigmoid for GLM4_MOE) + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + + # Routed scaling factor + if (routed_scaling_factor := self.hparams.get("routed_scaling_factor")) is not None: + 
self.gguf_writer.add_expert_weights_scale(routed_scaling_factor) + + # Normalise topk probabilities + if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None: + self.gguf_writer.add_expert_weights_norm(norm_topk_prob) + + _experts: list[dict[str, Tensor]] | None = None + _shared_experts: list[dict[str, Tensor]] | None = None + + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: + if name.startswith("model.visual."): # ignore visual part + return [] + elif name.startswith("model.language_model."): + name = name.replace("language_model.", "") # for multimodal variants + + # Handle main token embedding (but not layer-specific NextN embeddings) + if name == "model.embed_tokens.weight" and ".layers." not in name: + return [(self.map_tensor_name("token_embd.weight"), data_torch)] + + # Handle routed experts + if name.find("mlp.experts") != -1 and "shared_experts" not in name: + n_experts = self.hparams["n_routed_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + # Extend experts array if needed (for models where actual layers > num_hidden_layers) + while len(self._experts) <= bid: + self._experts.append({}) + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + # Generate GGUF tensor names for merged experts + if w_name == "down_proj": + new_name = f"blk.{bid}.ffn_down_exps.weight" + elif w_name == "gate_proj": + new_name = f"blk.{bid}.ffn_gate_exps.weight" + elif w_name == "up_proj": + new_name = f"blk.{bid}.ffn_up_exps.weight" + else: + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + new_name = self.map_tensor_name(merged_name) + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + # Handle expert gating input (routing gate) + if ".mlp.gate.e_score_correction_bias" in name: + new_name = name.replace("model.layers.", "blk.").replace( + ".mlp.gate.e_score_correction_bias", ".ffn_gate_inp.bias" + ) + return [(new_name, data_torch)] + elif ".mlp.gate.weight" in name: + new_name = name.replace("model.layers.", "blk.").replace( + ".mlp.gate.weight", ".ffn_gate_inp.weight" + ) + return [(new_name, data_torch)] + + # Handle shared expert tensors + if ".mlp.shared_experts." in name: + new_name = name.replace("model.layers.", "blk.").replace(".mlp.shared_experts.", ".ffn_") + if "gate_proj" in new_name: + new_name = new_name.replace("gate_proj", "gate_shexp") + elif "down_proj" in new_name: + new_name = new_name.replace("down_proj", "down_shexp") + elif "up_proj" in new_name: + new_name = new_name.replace("up_proj", "up_shexp") + return [(new_name, data_torch)] + + # Handle regular dense FFN layers (for hybrid dense/MoE architecture) + if ".mlp." 
in name and "experts" not in name and "_shexp" not in name: + if "gate_proj" in name: + new_name = name.replace("model.layers.", "blk.").replace( + ".mlp.gate_proj.weight", ".ffn_gate.weight" + ) + elif "up_proj" in name: + new_name = name.replace("model.layers.", "blk.").replace( + ".mlp.up_proj.weight", ".ffn_up.weight" + ) + elif "down_proj" in name: + new_name = name.replace("model.layers.", "blk.").replace( + ".mlp.down_proj.weight", ".ffn_down.weight" + ) + else: + new_name = name + return [(self.map_tensor_name(new_name), data_torch)] + + # Handle special NextN tensors - preserve for future MTP support + if ( + ".embed_tokens." in name + or ".shared_head." in name + or ".eh_proj." in name + or ".enorm." in name + or ".hnorm." in name + ): + new_name = name.replace("model.layers.", "blk.").replace("model.", "").replace(".weight", "") + return [(new_name, data_torch)] + + # GLM tensor mapping - handle directly without map_tensor_name + if ".input_layernorm." in name: + new_name = name.replace("model.layers.", "blk.").replace(".input_layernorm.", ".attn_norm.") + return [(new_name, data_torch)] + elif ".post_attention_layernorm." in name: + new_name = name.replace("model.layers.", "blk.").replace(".post_attention_layernorm.", ".ffn_norm.") + return [(new_name, data_torch)] + elif ".self_attn." in name: + # Map GLM self_attn to standard attention naming + new_name = name.replace("model.layers.", "blk.").replace(".self_attn.", ".attn_") + if "q_proj" in new_name: + new_name = new_name.replace("q_proj", "q") + elif "k_proj" in new_name: + new_name = new_name.replace("k_proj", "k") + elif "v_proj" in new_name: + new_name = new_name.replace("v_proj", "v") + elif "o_proj" in new_name: + new_name = new_name.replace("o_proj", "output") + return [(new_name, data_torch)] + + return super().modify_tensors(data_torch, name, bid) + + def prepare_tensors(self): + super().prepare_tensors() + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration") class ChatGLMModel(TextModel): model_arch = gguf.MODEL_ARCH.CHATGLM diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c97b61d09c711..82b044e068c5a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -354,6 +354,7 @@ class MODEL_ARCH(IntEnum): DEEPSEEK2 = auto() CHATGLM = auto() GLM4 = auto() + GLM4_MOE = auto() BITNET = auto() T5 = auto() T5ENCODER = auto() @@ -422,6 +423,9 @@ class MODEL_TENSOR(IntEnum): FFN_GATE_EXP = auto() FFN_DOWN_EXP = auto() FFN_UP_EXP = auto() + FFN_GATE_EXPS = auto() # merged experts + FFN_DOWN_EXPS = auto() # merged experts + FFN_UP_EXPS = auto() # merged experts FFN_GATE_SHEXP = auto() FFN_DOWN_SHEXP = auto() FFN_UP_SHEXP = auto() @@ -609,6 +613,12 @@ class MODEL_TENSOR(IntEnum): A_MMPROJ_FC = auto() A_MM_NORM_PRE = auto() A_MM_NORM_MID = auto() + NEXTN_EH_PROJ = auto() # nextn tensors (glm4moe) + NEXTN_EMBED_TOKENS = auto() # nextn tensors (glm4moe) + NEXTN_ENORM = auto() # nextn tensors (glm4moe) + NEXTN_HNORM = auto() # nextn tensors (glm4moe) + NEXTN_SHARED_HEAD_HEAD = auto() # nextn tensors (glm4moe) + NEXTN_SHARED_HEAD_NORM = auto() # nextn tensors (glm4moe) MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -673,6 +683,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.DEEPSEEK2: "deepseek2", MODEL_ARCH.CHATGLM: "chatglm", 
MODEL_ARCH.GLM4: "glm4", + MODEL_ARCH.GLM4_MOE: "glm4moe", MODEL_ARCH.BITNET: "bitnet", MODEL_ARCH.T5: "t5", MODEL_ARCH.T5ENCODER: "t5encoder", @@ -747,6 +758,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", + MODEL_TENSOR.FFN_GATE_EXPS: "blk.{bid}.ffn_gate_exps", # merged experts + MODEL_TENSOR.FFN_DOWN_EXPS: "blk.{bid}.ffn_down_exps", # merged experts + MODEL_TENSOR.FFN_UP_EXPS: "blk.{bid}.ffn_up_exps", # merged experts MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b", MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n @@ -929,6 +943,13 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc", MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre", MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid", + # NextN/MTP tensors (GLM4_MOE) + MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.eh_proj", + MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.embed_tokens", + MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.enorm", + MODEL_TENSOR.NEXTN_HNORM: "blk.{bid}.hnorm", + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.shared_head.head", + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.shared_head.norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -2102,6 +2123,36 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_POST_NORM, MODEL_TENSOR.FFN_POST_NORM, ], + MODEL_ARCH.GLM4_MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, # dense layers + MODEL_TENSOR.FFN_DOWN, # dense layers + MODEL_TENSOR.FFN_UP, # dense layers + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXPS, + MODEL_TENSOR.FFN_DOWN_EXPS, + MODEL_TENSOR.FFN_UP_EXPS, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + # NextN/MTP tensors - preserved but unused + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, + ], MODEL_ARCH.BITNET: [ MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index dbf977443ae85..4dd63c3832a94 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -62,6 +62,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, + { LLM_ARCH_GLM4_MOE, "glm4moe" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, @@ -1389,6 +1390,39 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, }, }, + { + LLM_ARCH_GLM4_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, // dense layers + 
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, // dense layers + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, // dense layers + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + // NextN/MTP tensors - preserved but unused (treated as output tensors) + { LLM_TENSOR_NEXTN_EH_PROJ, "blk.46.eh_proj" }, + { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.46.embed_tokens" }, + { LLM_TENSOR_NEXTN_ENORM, "blk.46.enorm" }, + { LLM_TENSOR_NEXTN_HNORM, "blk.46.hnorm" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.46.shared_head.head" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.46.shared_head.norm" }, + }, + }, { LLM_ARCH_BITNET, { @@ -2142,6 +2176,14 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + // NextN/MTP tensors are loaded but never used (reserved for future MTP support) + // These tensors only exist in the last layer (layer 46 for GLM-4.5-Air) and are treated as output tensors + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-arch.h b/src/llama-arch.h index 8267a8d3aa491..73e546673bc7b 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -66,6 +66,7 @@ enum llm_arch { LLM_ARCH_DEEPSEEK2, LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, + LLM_ARCH_GLM4_MOE, LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, @@ -407,6 +408,12 @@ enum llm_tensor { LLM_TENSOR_SHORTCONV_CONV, LLM_TENSOR_SHORTCONV_INPROJ, LLM_TENSOR_SHORTCONV_OUTPROJ, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, }; enum llm_tensor_layer { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e3aa9e6f91af9..26a7a34de7f32 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -111,6 +111,8 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_300B_A47B: return "300B.A47B"; + case LLM_TYPE_106B_A12B: return "106B.A12B"; + case LLM_TYPE_355B_A32B: return "355B.A32B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; default: return "?B"; @@ -1417,6 +1419,29 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GLM4_MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, 0); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, 
hparams.n_expert_used, 0); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, 0); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, 0); + + // Expert gating function (GLM4_MOE uses sigmoid) + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + switch (hparams.n_layer) { + case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) + case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -4345,6 +4370,101 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_GLM4_MOE: + { + const int64_t n_expert = hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + // NextN/MTP tensors (preserved but unused) - treated as output tensors + create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_ENORM), { n_embd }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_HNORM), { n_embd }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM), { n_embd }, TENSOR_NOT_REQUIRED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + // GLM-style attention with bias terms + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, TENSOR_NOT_REQUIRED); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // K/Q norm tensors (optional for GLM-4.5 355B variant) + layer.attn_q_norm = create_tensor( + tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor( + tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd 
}, 0); + + // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead + const bool use_moe = + (hparams.n_expert > 0) && (static_cast(i) >= hparams.n_layer_dense_lead); + + if (use_moe) { + // MoE layers + layer.ffn_gate_inp = + create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), { n_expert }, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers"); + } + if (n_expert_used == 0) { + GGML_ASSERT(hparams.n_expert_used > 0 && + "n_expert_used must be > 0 for GLM4_MOE MoE layers"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor( + tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = create_tensor( + tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = create_tensor( + tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + + // Shared expert + if (n_expert_shared > 0) { + const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; + layer.ffn_gate_shexp = create_tensor( + tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor( + tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + layer.ffn_up_shexp = + create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + } + } else { + // Dense layers (first k layers) - GLM uses separate gate/up projections + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + } + + } + } + break; case LLM_ARCH_NEMOTRON: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -13349,6 +13469,166 @@ struct llm_build_glm4 : public llm_graph_context { } }; +struct llm_build_glm4_moe : public llm_graph_context { + llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // Pre-attention norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + 
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply Q/K norm if available (GLM-4.5 355B variant) + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + } + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + } + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // FFN - hybrid dense/MoE layers + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense) + if (static_cast(il) < hparams.n_layer_dense_lead) { + // Dense FFN layer + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE layer with shared experts + const int64_t n_expert = hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, + il); + cb(moe_out, "ffn_moe_out", il); + + // Add shared expert computation + ggml_tensor * cur_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur_shexp, "ffn_shexp_out", il); + + // Combine MoE output with shared expert output + cur = ggml_add(ctx0, moe_out, cur_shexp); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_nemotron : public llm_graph_context { llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -17509,6 +17789,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case 
LLM_ARCH_GLM4_MOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_BITNET: { llm = std::make_unique(*this, params); @@ -17781,6 +18065,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_PLM: case LLM_ARCH_CHATGLM: case LLM_ARCH_GLM4: + case LLM_ARCH_GLM4_MOE: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_GRANITE_HYBRID: diff --git a/src/llama-model.h b/src/llama-model.h index 094e23808a813..5e71247e37cec 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -103,6 +103,8 @@ enum llm_type { LLM_TYPE_30B_A3B, LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big + LLM_TYPE_106B_A12B, // GLM-4.5-Air (106B total, 12B active) + LLM_TYPE_355B_A32B, // GLM-4.5 (355B total, 32B active) LLM_TYPE_E2B, LLM_TYPE_E4B, }; From 0edf7321b450a83891a6b69954a87bd0bba16e93 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 30 Jul 2025 11:50:36 +1000 Subject: [PATCH 02/48] feat: support GLM 4.5 family of models --- src/llama-model.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 26a7a34de7f32..6d30cdf1e51a0 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4418,6 +4418,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead + // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE const bool use_moe = (hparams.n_expert > 0) && (static_cast(i) >= hparams.n_layer_dense_lead); @@ -13586,7 +13587,7 @@ struct llm_build_glm4_moe : public llm_graph_context { n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0, - LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, + (llama_expert_gating_func_type) hparams.expert_gating_func, il); cb(moe_out, "ffn_moe_out", il); From 6b478bb76f63dcee805afb566a04a9e44cd3c04d Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 30 Jul 2025 13:10:47 +1000 Subject: [PATCH 03/48] feat: support GLM 4.5 family of models --- convert_hf_to_gguf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 12f22df249e57..f472434502864 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -678,6 +678,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2": # ref: https://huggingface.co/THUDM/glm-4-9b-hf res = "glm4" + if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902": + # ref: https://huggingface.co/zai-org/GLM-4.5-Air, https://huggingface.co/zai-org/GLM-4.5 + res = "gpt-2" if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 res = "minerva-7b" From 9652812ca788a2458b652f397daa9919b904d0d1 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 30 Jul 2025 13:32:59 +1000 Subject: [PATCH 04/48] feat: support GLM 4.5 family of models --- src/llama-arch.cpp | 14 +++++++------- src/llama-model.cpp | 15 ++++++++------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 4dd63c3832a94..1fb94ed4614a8 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1414,13 +1414,13 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, - // 
NextN/MTP tensors - preserved but unused (treated as output tensors) - { LLM_TENSOR_NEXTN_EH_PROJ, "blk.46.eh_proj" }, - { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.46.embed_tokens" }, - { LLM_TENSOR_NEXTN_ENORM, "blk.46.enorm" }, - { LLM_TENSOR_NEXTN_HNORM, "blk.46.hnorm" }, - { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.46.shared_head.head" }, - { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.46.shared_head.norm" }, + // NextN/MTP tensors - preserved but unused (in final layer, dynamic layer number) + { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.eh_proj" }, + { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.embed_tokens" }, + { LLM_TENSOR_NEXTN_ENORM, "blk.%d.enorm" }, + { LLM_TENSOR_NEXTN_HNORM, "blk.%d.hnorm" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.shared_head.head" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.shared_head.norm" }, }, }, { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 6d30cdf1e51a0..e6c75e425d26b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4386,13 +4386,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); } - // NextN/MTP tensors (preserved but unused) - treated as output tensors - create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED); - create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - create_tensor(tn(LLM_TENSOR_NEXTN_ENORM), { n_embd }, TENSOR_NOT_REQUIRED); - create_tensor(tn(LLM_TENSOR_NEXTN_HNORM), { n_embd }, TENSOR_NOT_REQUIRED); - create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM), { n_embd }, TENSOR_NOT_REQUIRED); + // NextN/MTP tensors (preserved but unused) - in final layer (dynamic layer number) + const int final_layer = n_layer - 1; // NextN tensors are in the last layer + create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", final_layer), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", final_layer), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", final_layer), { n_embd }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", final_layer), { n_embd }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", final_layer), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", final_layer), { n_embd }, TENSOR_NOT_REQUIRED); for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; From 07bb0dd1e80f23aec522d774aaf29e2a3fcfc7c9 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 30 Jul 2025 14:28:34 +1000 Subject: [PATCH 05/48] feat: support GLM 4.5 family of models --- src/llama-model.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e6c75e425d26b..3f541cd8a5375 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4388,12 +4388,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // NextN/MTP tensors (preserved but unused) - in final layer (dynamic layer number) const int final_layer = n_layer - 1; // NextN tensors are in the last layer - create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", final_layer), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED); - create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", final_layer), { n_embd, n_vocab }, 
TENSOR_NOT_REQUIRED); - create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", final_layer), { n_embd }, TENSOR_NOT_REQUIRED); - create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", final_layer), { n_embd }, TENSOR_NOT_REQUIRED); - create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", final_layer), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", final_layer), { n_embd }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, final_layer), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, final_layer), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, final_layer), { n_embd }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, final_layer), { n_embd }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, final_layer), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, final_layer), { n_embd }, TENSOR_NOT_REQUIRED); for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; From fae4df8ee05c57bf2e2d5e3f0b094010f34dc86f Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 30 Jul 2025 18:21:40 +1000 Subject: [PATCH 06/48] feat: support GLM 4.5 family of models --- convert_hf_to_gguf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index f472434502864..ff7bda6eb001b 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6603,19 +6603,19 @@ def set_vocab(self): self.gguf_writer.add_tokenizer_pre(tokpre) self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) - special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + + # Set special tokens special_vocab._set_special_token( "eos", tokenizer.get_added_vocab()["<|endoftext|>"] ) special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) - special_vocab._set_special_token("eog", tokenizer.get_added_vocab()["<|user|>"]) - special_vocab._set_special_token("eog", tokenizer.get_added_vocab()["<|observation|>"]) special_vocab._set_special_token( "unk", tokenizer.get_added_vocab()["<|endoftext|>"] ) special_vocab._set_special_token( "bos", tokenizer.get_added_vocab()["<|endoftext|>"] ) + special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): From 03fad044941891a4b69a190450d0ae74f2954812 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 31 Jul 2025 07:39:06 +1000 Subject: [PATCH 07/48] feat: support GLM 4.5 family of models --- src/llama-model.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 3f541cd8a5375..0626379473777 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13584,7 +13584,7 @@ struct llm_build_glm4_moe : public llm_graph_context { model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, - nullptr, + model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0, From b61fc918829c4fe33a94971218a35fe097a38553 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 31 Jul 2025 12:09:02 +1000 Subject: [PATCH 08/48] feat: support GLM 4.5 family of models --- src/llama-model.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 0626379473777..66bb67dcc791b 100644 --- 
a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1429,6 +1429,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, 0); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, 0); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, 0); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); // Expert gating function (GLM4_MOE uses sigmoid) ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); @@ -13587,7 +13588,7 @@ struct llm_build_glm4_moe : public llm_graph_context { model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + true, hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, il); cb(moe_out, "ffn_moe_out", il); From 999c07a2e28709e429d659561e98bd9c3ca25356 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 31 Jul 2025 12:49:02 +1000 Subject: [PATCH 09/48] feat: support GLM 4.5 family of models --- convert_hf_to_gguf.py | 14 ++++++++++++++ src/llama-model.cpp | 3 ++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ff7bda6eb001b..8714079e2c3bc 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6616,6 +6616,20 @@ def set_vocab(self): "bos", tokenizer.get_added_vocab()["<|endoftext|>"] ) special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 + + # Fix chat template syntax error in GLM-4.5 models + if special_vocab.chat_template and isinstance(special_vocab.chat_template, str): + # Fix multiple syntax issues in GLM-4.5 chat template + template = special_vocab.chat_template + # Fix nested double quotes issue + template = template.replace('endswith("/nothink")', "endswith('/nothink')") + # Fix any other potential parentheses/tuple issues + template = template.replace( + "not visible_text(m.content).endswith('/nothink'))", + "not visible_text(m.content).endswith('/nothink')" + ) + special_vocab.chat_template = template + special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 66bb67dcc791b..d14f35d73d6cc 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1430,6 +1430,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, 0); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, 0); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); // Expert gating function (GLM4_MOE uses sigmoid) ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); @@ -13587,7 +13588,7 @@ struct llm_build_glm4_moe : public llm_graph_context { model.layers[il].ffn_down_exps, model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, - LLM_FFN_SILU, true, + LLM_FFN_SILU, hparams.expert_weights_norm, true, hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, il); From 5baa60717c950cfc6251d729c8a794c03e038401 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 31 Jul 2025 17:45:48 +1000 Subject: [PATCH 10/48] feat: support GLM 4.5 family of models --- src/llama-arch.cpp | 1 + src/llama-model.cpp | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 1fb94ed4614a8..90657ad12a171 100644 --- 
a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1414,6 +1414,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, // NextN/MTP tensors - preserved but unused (in final layer, dynamic layer number) { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.eh_proj" }, { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.embed_tokens" }, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index d14f35d73d6cc..183bdf88a81ad 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4388,8 +4388,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); } - // NextN/MTP tensors (preserved but unused) - in final layer (dynamic layer number) - const int final_layer = n_layer - 1; // NextN tensors are in the last layer + // NextN/MTP tensors (preserved but unused) - only in final layer (46 for Air, 92 for GLM-4.5) + const int final_layer = n_layer - 1; // NextN tensors are in the last layer only create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, final_layer), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED); create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, final_layer), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, final_layer), { n_embd }, TENSOR_NOT_REQUIRED); @@ -4406,9 +4406,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, TENSOR_NOT_REQUIRED); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); @@ -4429,7 +4429,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // MoE layers layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), { n_expert }, TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, i), { n_expert }, 0); if (n_expert == 0) { GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers"); From 62447f82ba8b612a9e114bd4d12335c9d2b61497 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 31 Jul 2025 21:45:51 +1000 Subject: [PATCH 11/48] Update convert_hf_to_gguf.py --- convert_hf_to_gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 8714079e2c3bc..0b93091a25ada 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -680,7 +680,7 @@ def get_vocab_base_pre(self, tokenizer) -> str: res = "glm4" if chkhsh == 
"9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902": # ref: https://huggingface.co/zai-org/GLM-4.5-Air, https://huggingface.co/zai-org/GLM-4.5 - res = "gpt-2" + res = "glm4" if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 res = "minerva-7b" From 58898b5ea7d872366f09755468b129fe07cca979 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 31 Jul 2025 21:54:43 +1000 Subject: [PATCH 12/48] feat: support GLM 4.5 family of models --- convert_hf_to_gguf.py | 54 +++++++++++++++++++++++++++++++------------ 1 file changed, 39 insertions(+), 15 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0b93091a25ada..9405077fbcba9 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6604,29 +6604,54 @@ def set_vocab(self): self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) - # Set special tokens + # Special tokens + # BOS should be [gMASK] (151331), EOT should be <|endoftext|> (151329) special_vocab._set_special_token( "eos", tokenizer.get_added_vocab()["<|endoftext|>"] ) - special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) + special_vocab._set_special_token( + "eot", tokenizer.get_added_vocab()["<|endoftext|>"] + ) special_vocab._set_special_token( "unk", tokenizer.get_added_vocab()["<|endoftext|>"] ) special_vocab._set_special_token( - "bos", tokenizer.get_added_vocab()["<|endoftext|>"] + "bos", tokenizer.get_added_vocab()["[gMASK]"] # 151331 ) special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 - # Fix chat template syntax error in GLM-4.5 models + # Fix chat template syntax error if special_vocab.chat_template and isinstance(special_vocab.chat_template, str): # Fix multiple syntax issues in GLM-4.5 chat template template = special_vocab.chat_template - # Fix nested double quotes issue - template = template.replace('endswith("/nothink")', "endswith('/nothink')") - # Fix any other potential parentheses/tuple issues + # Fix missing closing parenthesis in conditional expression + template = template.replace( + 'endswith("/nothink")) else', + 'endswith("/nothink"))) else' + ) + template = template.replace( + "endswith('/nothink')) else", + "endswith('/nothink'))) else" + ) + # llama.cpp's C++ Jinja2 parser doesn't support visible_text() or .endswith() template = template.replace( - "not visible_text(m.content).endswith('/nothink'))", - "not visible_text(m.content).endswith('/nothink')" + "visible_text(m.content).endswith('/nothink')", + "'/nothink' in m.content" + ) + template = template.replace( + "visible_text(m.content).endswith(\"/nothink\")", + "\"/nothink\" in m.content" + ) + # Remove visible_text() function calls entirely as they're not supported + template = template.replace("visible_text(m.content)", "m.content") + # Fix parenthesis mismatch in chat template + template = template.replace( + 'not "/nothink" in m.content)) else', + 'not "/nothink" in m.content) else' + ) # Remove extra closing parenthesis + template = template.replace( + "not '/nothink' in m.content)) else", + "not '/nothink' in m.content) else" ) special_vocab.chat_template = template @@ -6642,10 +6667,9 @@ def set_gguf_parameters(self): int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)) ) - # MoE parameters - if (n_experts := self.hparams.get("n_routed_experts")) is not None: - self.gguf_writer.add_expert_count(n_experts) - # Note: expert_used_count is already set by 
parent class using num_experts_per_tok + # MoE parameters - Use only routed expert count (shared experts handled separately) + if (n_routed_experts := self.hparams.get("n_routed_experts")) is not None: + self.gguf_writer.add_expert_count(n_routed_experts) if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None: @@ -6721,10 +6745,10 @@ def modify_tensors( else: return [] - # Handle expert gating input (routing gate) + # Handle expert gating input (routing gate) - routed experts only if ".mlp.gate.e_score_correction_bias" in name: new_name = name.replace("model.layers.", "blk.").replace( - ".mlp.gate.e_score_correction_bias", ".ffn_gate_inp.bias" + ".mlp.gate.e_score_correction_bias", ".exp_probs_b" ) return [(new_name, data_torch)] elif ".mlp.gate.weight" in name: From ab3183e646aae2bbc65bedf874c93e2cec30122d Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 31 Jul 2025 22:16:40 +1000 Subject: [PATCH 13/48] feat: support GLM 4.5 family of models --- src/llama-model.cpp | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 183bdf88a81ad..0c380894b240e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13580,6 +13580,15 @@ struct llm_build_glm4_moe : public llm_graph_context { const int64_t n_expert = hparams.n_expert; const int64_t n_expert_used = hparams.n_expert_used; + // Compute shared expert output first + ggml_tensor * cur_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur_shexp, "ffn_shexp_out", il); + ggml_tensor * moe_out = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, @@ -13594,16 +13603,12 @@ struct llm_build_glm4_moe : public llm_graph_context { il); cb(moe_out, "ffn_moe_out", il); - // Add shared expert computation - ggml_tensor * cur_shexp = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, - model.layers[il].ffn_gate_shexp, NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur_shexp, "ffn_shexp_out", il); + // For GLM4_MOE: Shared expert is always active alongside routed experts + // Apply proper scaling to shared expert to match architectural design + cur_shexp = ggml_scale(ctx0, cur_shexp, hparams.expert_weights_scale); + cb(cur_shexp, "ffn_shexp_scaled", il); - // Combine MoE output with shared expert output + // Combine with proper mathematical balance cur = ggml_add(ctx0, moe_out, cur_shexp); cb(cur, "ffn_out", il); } From 6f3d94eba4e699f57f9716302b4cdc43307415f1 Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 1 Aug 2025 07:58:57 +1000 Subject: [PATCH 14/48] feat: support GLM 4.5 family of models --- src/llama-model.cpp | 46 ++++++++++++++++++++++----------------------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 0c380894b240e..349fb3afa17b3 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13580,36 +13580,34 @@ struct llm_build_glm4_moe : public llm_graph_context { const int64_t n_expert = hparams.n_expert; const int64_t n_expert_used = hparams.n_expert_used; - // Compute shared expert output first - ggml_tensor * cur_shexp = build_ffn(cur, + // Save original input for shared expert + ggml_tensor * 
residuals = cur; + + // Process routed experts using existing MoE infrastructure + ggml_tensor * routed_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(routed_out, "ffn_moe_out", il); + + // Process shared expert on original input + ggml_tensor * shared_out = build_ffn(residuals, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur_shexp, "ffn_shexp_out", il); - - ggml_tensor * moe_out = - build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - model.layers[il].ffn_exp_probs_b, - n_expert, n_expert_used, - LLM_FFN_SILU, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, - (llama_expert_gating_func_type) hparams.expert_gating_func, - il); - cb(moe_out, "ffn_moe_out", il); - - // For GLM4_MOE: Shared expert is always active alongside routed experts - // Apply proper scaling to shared expert to match architectural design - cur_shexp = ggml_scale(ctx0, cur_shexp, hparams.expert_weights_scale); - cb(cur_shexp, "ffn_shexp_scaled", il); + cb(shared_out, "ffn_shexp_out", il); - // Combine with proper mathematical balance - cur = ggml_add(ctx0, moe_out, cur_shexp); + // Final output: routed_output + shared_output + cur = ggml_add(ctx0, routed_out, shared_out); cb(cur, "ffn_out", il); } From b25f462395c51925ec8350f0130da9cd0ac3c72a Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 1 Aug 2025 08:26:11 +1000 Subject: [PATCH 15/48] feat: support GLM 4.5 family of models --- convert_hf_to_gguf.py | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 9405077fbcba9..7dcf16d3cac5a 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6620,40 +6620,6 @@ def set_vocab(self): ) special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 - # Fix chat template syntax error - if special_vocab.chat_template and isinstance(special_vocab.chat_template, str): - # Fix multiple syntax issues in GLM-4.5 chat template - template = special_vocab.chat_template - # Fix missing closing parenthesis in conditional expression - template = template.replace( - 'endswith("/nothink")) else', - 'endswith("/nothink"))) else' - ) - template = template.replace( - "endswith('/nothink')) else", - "endswith('/nothink'))) else" - ) - # llama.cpp's C++ Jinja2 parser doesn't support visible_text() or .endswith() - template = template.replace( - "visible_text(m.content).endswith('/nothink')", - "'/nothink' in m.content" - ) - template = template.replace( - "visible_text(m.content).endswith(\"/nothink\")", - "\"/nothink\" in m.content" - ) - # Remove visible_text() function calls entirely as they're not supported - template = template.replace("visible_text(m.content)", "m.content") - # Fix parenthesis mismatch in chat template - template = template.replace( - 'not "/nothink" in m.content)) else', - 'not "/nothink" in m.content) else' - ) # Remove extra closing parenthesis - template = template.replace( - "not '/nothink' in m.content)) else", - "not 
'/nothink' in m.content) else" - ) - special_vocab.chat_template = template special_vocab.add_to_gguf(self.gguf_writer) From c90f63aa31c5edcb5e899994fe2134144d1c5a42 Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 1 Aug 2025 22:05:24 +1000 Subject: [PATCH 16/48] feat: support GLM 4.5 family of models --- src/llama-model.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 349fb3afa17b3..57e337d0a16d9 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13476,6 +13476,7 @@ struct llm_build_glm4 : public llm_graph_context { struct llm_build_glm4_moe : public llm_graph_context { llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -13549,7 +13550,7 @@ struct llm_build_glm4_moe : public llm_graph_context { cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } From bdfe09c5c06738310d363c1ab1a474e25e15ceb5 Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 1 Aug 2025 22:27:13 +1000 Subject: [PATCH 17/48] feat: support GLM 4.5 family of models --- src/llama-graph.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 1b9cc4aec0632..32ee267631e91 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -760,8 +760,8 @@ ggml_tensor * llm_graph_context::build_ffn( if (down) { cur = build_lora_mm(down, cur); - if (arch == LLM_ARCH_GLM4) { - // GLM4 seems to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { + // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } @@ -1481,8 +1481,8 @@ ggml_tensor * llm_graph_context::build_attn( if (wo) { cur = build_lora_mm(wo, cur); - if (arch == LLM_ARCH_GLM4) { - // GLM4 seems to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { + // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } From 3d15c4a940dc262c7ca4ea1bc6c246a6881a19bf Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 1 Aug 2025 23:20:22 +1000 Subject: [PATCH 18/48] feat: support GLM 4.5 family of models --- convert_hf_to_gguf.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7dcf16d3cac5a..b19842de9c2ed 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6605,12 +6605,12 @@ def set_vocab(self): self.gguf_writer.add_token_types(toktypes) # Special tokens - # BOS should be [gMASK] (151331), EOT should be <|endoftext|> (151329) + # BOS should be [gMASK] (151331), EOS should be <|endoftext|> (151329) as per official config special_vocab._set_special_token( - "eos", tokenizer.get_added_vocab()["<|endoftext|>"] + "eos", tokenizer.get_added_vocab()["<|endoftext|>"] # 151329 - official EOS token ) special_vocab._set_special_token( - "eot", tokenizer.get_added_vocab()["<|endoftext|>"] + "eot", tokenizer.get_added_vocab()["<|endoftext|>"] # 151329 - same as EOS ) special_vocab._set_special_token( "unk", tokenizer.get_added_vocab()["<|endoftext|>"] 
@@ -6620,6 +6620,9 @@ def set_vocab(self): ) special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 + if "/nothink" in tokenizer.get_added_vocab(): + special_vocab._set_special_token("nothink", tokenizer.get_added_vocab()["/nothink"]) # 151360 + # Note: and are regular tokens (special=false in official config), not special tokens special_vocab.add_to_gguf(self.gguf_writer) @@ -6654,6 +6657,9 @@ def set_gguf_parameters(self): if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None: self.gguf_writer.add_expert_weights_norm(norm_topk_prob) + # GLM models should not prepend BOS token + self.gguf_writer.add_add_bos_token(False) + _experts: list[dict[str, Tensor]] | None = None _shared_experts: list[dict[str, Tensor]] | None = None From dbfadb661e15e080df9962dca35f23eba9610955 Mon Sep 17 00:00:00 2001 From: Sam Date: Sat, 2 Aug 2025 15:31:55 +1000 Subject: [PATCH 19/48] feat: support GLM 4.5 family of models --- convert_hf_to_gguf.py | 26 ++++++++++++++++--- models/templates/README.md | 3 ++- .../requirements-convert_legacy_llama.txt | 2 +- src/llama-kv-cache-unified.cpp | 4 +++ src/llama-model.cpp | 9 +++++-- 5 files changed, 36 insertions(+), 8 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index b19842de9c2ed..b386177c2eeb4 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6605,9 +6605,9 @@ def set_vocab(self): self.gguf_writer.add_token_types(toktypes) # Special tokens - # BOS should be [gMASK] (151331), EOS should be <|endoftext|> (151329) as per official config + # BOS should be [gMASK] (151331), EOS should be <|endoftext|> (151329) as per tokenizer analysis special_vocab._set_special_token( - "eos", tokenizer.get_added_vocab()["<|endoftext|>"] # 151329 - official EOS token + "eos", tokenizer.get_added_vocab()["<|endoftext|>"] # 151329 - correct EOS token ) special_vocab._set_special_token( "eot", tokenizer.get_added_vocab()["<|endoftext|>"] # 151329 - same as EOS @@ -6620,9 +6620,25 @@ def set_vocab(self): ) special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 - if "/nothink" in tokenizer.get_added_vocab(): - special_vocab._set_special_token("nothink", tokenizer.get_added_vocab()["/nothink"]) # 151360 + if "" in tokenizer.get_added_vocab(): + special_vocab._set_special_token("sop", tokenizer.get_added_vocab()[""]) # 151333 + if "" in tokenizer.get_added_vocab(): + special_vocab._set_special_token("eop", tokenizer.get_added_vocab()[""]) # 151334 + if "[sMASK]" in tokenizer.get_added_vocab(): + special_vocab._set_special_token("smask", tokenizer.get_added_vocab()["[sMASK]"]) # 151332 + + # TODO: clean up once decided on an approach to think and /nothink + # + # Previously: + # if "/nothink" in tokenizer.get_added_vocab(): + # special_vocab._set_special_token("nothink", tokenizer.get_added_vocab()["/nothink"]) # 151360 # Note: and are regular tokens (special=false in official config), not special tokens + # + # Latest thinking is: + # NOTE: /nothink token exists but causes generation issues as mentioned in + # https://huggingface.co/zai-org/GLM-4.5/discussions/9 + # "it is a very special token. Even as input, it will be encoded into a special token, causing generation issues." 
+ # Therefore we do NOT add it to avoid generation problems special_vocab.add_to_gguf(self.gguf_writer) @@ -6639,6 +6655,8 @@ def set_gguf_parameters(self): # MoE parameters - Use only routed expert count (shared experts handled separately) if (n_routed_experts := self.hparams.get("n_routed_experts")) is not None: self.gguf_writer.add_expert_count(n_routed_experts) + if (num_experts_per_tok := self.hparams.get("num_experts_per_tok")) is not None: + self.gguf_writer.add_expert_used_count(num_experts_per_tok) if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None: diff --git a/models/templates/README.md b/models/templates/README.md index 35b6386dd0649..2e8eaa5953b86 100644 --- a/models/templates/README.md +++ b/models/templates/README.md @@ -21,4 +21,5 @@ These templates can be updated with the following commands: ./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja ./scripts/get_chat_template.py Qwen/QwQ-32B > models/templates/Qwen-QwQ-32B.jinja ./scripts/get_chat_template.py Qwen/Qwen3-0.6B > models/templates/Qwen-Qwen3-0.6B.jinja -``` \ No newline at end of file +./scripts/get_chat_template.py zai-org/GLM-4.5 > models/templates/zai-org-GLM-4.5.jinja +``` diff --git a/requirements/requirements-convert_legacy_llama.txt b/requirements/requirements-convert_legacy_llama.txt index 859204b27ebb8..80676138037bd 100644 --- a/requirements/requirements-convert_legacy_llama.txt +++ b/requirements/requirements-convert_legacy_llama.txt @@ -1,5 +1,5 @@ numpy~=1.26.4 sentencepiece~=0.2.0 -transformers>=4.45.1,<5.0.0 +transformers>=4.54.1,<5.0.0 gguf>=0.1.0 protobuf>=4.21.0,<5.0.0 diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 321dc79fc36ab..7b9987edd03ff 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -39,6 +39,10 @@ llama_kv_cache_unified::llama_kv_cache_unified( if (model.arch == LLM_ARCH_GEMMA3N) { n_layer_cache = 20; } + if (model.arch == LLM_ARCH_GLM4_MOE) { + // GLM4_MOE: Only process first 46 transformer layers, skip NextN layer + n_layer_cache = hparams.n_layer - 1; + } // create a context for each buffer type std::map ctx_map; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 57e337d0a16d9..9f4aa9c878820 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4397,6 +4397,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, final_layer), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, final_layer), { n_embd }, TENSOR_NOT_REQUIRED); + // Load ALL tensors including NextN layer to satisfy tensor count (803) + // but only PROCESS first 46 transformer layers in forward pass for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -13492,7 +13494,10 @@ struct llm_build_glm4_moe : public llm_graph_context { ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // Only process first 46 transformer layers (skip NextN layer 46) + // Layer 46 tensors are loaded but not processed in forward pass + const int n_transformer_layers = n_layer - 1; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; // Pre-attention norm @@ -13554,7 +13559,7 @@ struct llm_build_glm4_moe : public llm_graph_context { Qcur, Kcur, Vcur, 
nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } From 133c5838253d7355b0aa0bd0fb26a3284a34374a Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 17:56:16 +1000 Subject: [PATCH 20/48] Apply suggestion from @CISC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- gguf-py/gguf/constants.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 82b044e068c5a..3762dd90e5d2a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -423,9 +423,6 @@ class MODEL_TENSOR(IntEnum): FFN_GATE_EXP = auto() FFN_DOWN_EXP = auto() FFN_UP_EXP = auto() - FFN_GATE_EXPS = auto() # merged experts - FFN_DOWN_EXPS = auto() # merged experts - FFN_UP_EXPS = auto() # merged experts FFN_GATE_SHEXP = auto() FFN_DOWN_SHEXP = auto() FFN_UP_SHEXP = auto() From 15b79e8eec72338298a556f89decaa4715f0e50e Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 17:56:22 +1000 Subject: [PATCH 21/48] Apply suggestion from @CISC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- gguf-py/gguf/constants.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 3762dd90e5d2a..d3e49ee99cce3 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -755,9 +755,6 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", - MODEL_TENSOR.FFN_GATE_EXPS: "blk.{bid}.ffn_gate_exps", # merged experts - MODEL_TENSOR.FFN_DOWN_EXPS: "blk.{bid}.ffn_down_exps", # merged experts - MODEL_TENSOR.FFN_UP_EXPS: "blk.{bid}.ffn_up_exps", # merged experts MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b", MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n From 4e8cf30ab2bcf2103218ffe4d7468de7f0db6a3b Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 17:56:29 +1000 Subject: [PATCH 22/48] Apply suggestion from @CISC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- gguf-py/gguf/constants.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index d3e49ee99cce3..f5ce3c64c233e 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -2122,6 +2122,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_POST_NORM, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, From a5434b8a5a5a76f00ee60c4f3e3a1ff91975e473 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 17:56:36 +1000 Subject: [PATCH 23/48] Apply suggestion from @CISC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- gguf-py/gguf/constants.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index f5ce3c64c233e..73359670fa7ad 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -2129,7 +2129,6 @@ class 
MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_OUT, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K_NORM, - MODEL_TENSOR.FFN_NORM, MODEL_TENSOR.FFN_GATE, # dense layers MODEL_TENSOR.FFN_DOWN, # dense layers MODEL_TENSOR.FFN_UP, # dense layers From 0d272cce21344888020b475ae03d89b2cbe1491c Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 17:56:43 +1000 Subject: [PATCH 24/48] Apply suggestion from @CISC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- gguf-py/gguf/constants.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 73359670fa7ad..b74836078b818 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -2133,9 +2133,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, # dense layers MODEL_TENSOR.FFN_UP, # dense layers MODEL_TENSOR.FFN_GATE_INP, - MODEL_TENSOR.FFN_GATE_EXPS, - MODEL_TENSOR.FFN_DOWN_EXPS, - MODEL_TENSOR.FFN_UP_EXPS, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, MODEL_TENSOR.FFN_GATE_SHEXP, MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, From 0da017ca231292c6e5881ec638466602580e0be0 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 17:56:50 +1000 Subject: [PATCH 25/48] Apply suggestion from @CISC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- gguf-py/gguf/constants.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index b74836078b818..abb18db8ae59b 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -2139,6 +2139,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE_SHEXP, MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, # NextN/MTP tensors - preserved but unused MODEL_TENSOR.NEXTN_EH_PROJ, MODEL_TENSOR.NEXTN_EMBED_TOKENS, From 90871fb92e6490f16091fffa831809dae927d4c2 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 17:56:57 +1000 Subject: [PATCH 26/48] Apply suggestion from @CISC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- src/llama-arch.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 90657ad12a171..340465cf91e5f 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1397,6 +1397,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, { LLM_TENSOR_OUTPUT, "output" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, From fb1d48d8d0cd52778e52ddd92be56331edf663d7 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 17:57:05 +1000 Subject: [PATCH 27/48] Apply suggestion from @CISC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- src/llama-arch.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 340465cf91e5f..a6a69839ecb63 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1404,7 +1404,6 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, { LLM_TENSOR_ATTN_K_NORM, 
"blk.%d.attn_k_norm" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, // dense layers { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, // dense layers { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, // dense layers From 7f23adf69a5e2d100449e084bc3fc7176599c4a7 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 17:57:43 +1000 Subject: [PATCH 28/48] Apply suggestion from @CISC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- convert_hf_to_gguf.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index b386177c2eeb4..0a5bea9694cff 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6594,9 +6594,7 @@ def __init__(self, *args, **kwargs): def set_vocab(self): from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained( - self.dir_model, trust_remote_code=True - ) + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) tokens, toktypes, tokpre = self.get_vocab_base() self.gguf_writer.add_tokenizer_model("gpt2") From 9b0b1b4f244c3e0c95f2b3a565e7cfc50d7245c9 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 17:57:56 +1000 Subject: [PATCH 29/48] Apply suggestion from @CISC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- convert_hf_to_gguf.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0a5bea9694cff..bb8bf6213004d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6718,16 +6718,10 @@ def modify_tensors( del self._experts[bid][ename] data_torch = torch.stack(datas, dim=0) - # Generate GGUF tensor names for merged experts - if w_name == "down_proj": - new_name = f"blk.{bid}.ffn_down_exps.weight" - elif w_name == "gate_proj": - new_name = f"blk.{bid}.ffn_gate_exps.weight" - elif w_name == "up_proj": - new_name = f"blk.{bid}.ffn_up_exps.weight" - else: - merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) tensors.append((new_name, data_torch)) return tensors else: From 6ffa4a33eebedc8f643121be6a424913f0426a41 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 17:58:07 +1000 Subject: [PATCH 30/48] Apply suggestion from @CISC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- convert_hf_to_gguf.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index bb8bf6213004d..d062da8466f4d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6673,9 +6673,6 @@ def set_gguf_parameters(self): if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None: self.gguf_writer.add_expert_weights_norm(norm_topk_prob) - # GLM models should not prepend BOS token - self.gguf_writer.add_add_bos_token(False) - _experts: list[dict[str, Tensor]] | None = None _shared_experts: list[dict[str, Tensor]] | None = None From 166c0025aac1419423b94bdbb09ad52f62936b48 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 17:58:14 +1000 Subject: [PATCH 31/48] Apply suggestion from @CISC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- convert_hf_to_gguf.py | 22 ++-------------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d062da8466f4d..465171efa5d32 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6776,27 +6776,9 @@ def modify_tensors( new_name = name.replace("model.layers.", "blk.").replace("model.", "").replace(".weight", "") return [(new_name, data_torch)] - # GLM tensor mapping - handle directly without map_tensor_name - if ".input_layernorm." in name: - new_name = name.replace("model.layers.", "blk.").replace(".input_layernorm.", ".attn_norm.") - return [(new_name, data_torch)] - elif ".post_attention_layernorm." in name: - new_name = name.replace("model.layers.", "blk.").replace(".post_attention_layernorm.", ".ffn_norm.") - return [(new_name, data_torch)] - elif ".self_attn." in name: - # Map GLM self_attn to standard attention naming - new_name = name.replace("model.layers.", "blk.").replace(".self_attn.", ".attn_") - if "q_proj" in new_name: - new_name = new_name.replace("q_proj", "q") - elif "k_proj" in new_name: - new_name = new_name.replace("k_proj", "k") - elif "v_proj" in new_name: - new_name = new_name.replace("v_proj", "v") - elif "o_proj" in new_name: - new_name = new_name.replace("o_proj", "output") - return [(new_name, data_torch)] + new_name = self.map_tensor_name(name) - return super().modify_tensors(data_torch, name, bid) + return [(new_name, data_torch)] def prepare_tensors(self): super().prepare_tensors() From e342dfddf339cb0de47774e7a079f2b9921da846 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 17:58:22 +1000 Subject: [PATCH 32/48] Apply suggestion from @CISC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- convert_hf_to_gguf.py | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 465171efa5d32..e358948277c62 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6736,35 +6736,6 @@ def modify_tensors( ) return [(new_name, data_torch)] - # Handle shared expert tensors - if ".mlp.shared_experts." in name: - new_name = name.replace("model.layers.", "blk.").replace(".mlp.shared_experts.", ".ffn_") - if "gate_proj" in new_name: - new_name = new_name.replace("gate_proj", "gate_shexp") - elif "down_proj" in new_name: - new_name = new_name.replace("down_proj", "down_shexp") - elif "up_proj" in new_name: - new_name = new_name.replace("up_proj", "up_shexp") - return [(new_name, data_torch)] - - # Handle regular dense FFN layers (for hybrid dense/MoE architecture) - if ".mlp." in name and "experts" not in name and "_shexp" not in name: - if "gate_proj" in name: - new_name = name.replace("model.layers.", "blk.").replace( - ".mlp.gate_proj.weight", ".ffn_gate.weight" - ) - elif "up_proj" in name: - new_name = name.replace("model.layers.", "blk.").replace( - ".mlp.up_proj.weight", ".ffn_up.weight" - ) - elif "down_proj" in name: - new_name = name.replace("model.layers.", "blk.").replace( - ".mlp.down_proj.weight", ".ffn_down.weight" - ) - else: - new_name = name - return [(self.map_tensor_name(new_name), data_torch)] - # Handle special NextN tensors - preserve for future MTP support if ( ".embed_tokens." 
in name From 4c24d5b1af8d32ec30e274aff2d4f5ae716fb130 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 17:58:31 +1000 Subject: [PATCH 33/48] Apply suggestion from @CISC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- convert_hf_to_gguf.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index e358948277c62..9d6e16eb56dbf 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6724,17 +6724,8 @@ def modify_tensors( else: return [] - # Handle expert gating input (routing gate) - routed experts only - if ".mlp.gate.e_score_correction_bias" in name: - new_name = name.replace("model.layers.", "blk.").replace( - ".mlp.gate.e_score_correction_bias", ".exp_probs_b" - ) - return [(new_name, data_torch)] - elif ".mlp.gate.weight" in name: - new_name = name.replace("model.layers.", "blk.").replace( - ".mlp.gate.weight", ".ffn_gate_inp.weight" - ) - return [(new_name, data_torch)] + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") # Handle special NextN tensors - preserve for future MTP support if ( From b5ed4a8a0e399e25c5e3fc333ccacd27e16779da Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 17:58:39 +1000 Subject: [PATCH 34/48] Apply suggestion from @CISC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- convert_hf_to_gguf.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 9d6e16eb56dbf..468e50df1223d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6696,10 +6696,6 @@ def modify_tensors( if self._experts is None: self._experts = [{} for _ in range(self.block_count)] - # Extend experts array if needed (for models where actual layers > num_hidden_layers) - while len(self._experts) <= bid: - self._experts.append({}) - self._experts[bid][name] = data_torch if len(self._experts[bid]) >= n_experts * 3: From f36c3b7b631c53ce9f4750bfa312af37a9d0dbed Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 17:58:46 +1000 Subject: [PATCH 35/48] Apply suggestion from @CISC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- convert_hf_to_gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 468e50df1223d..2cb47581e11f0 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6689,7 +6689,7 @@ def modify_tensors( return [(self.map_tensor_name("token_embd.weight"), data_torch)] # Handle routed experts - if name.find("mlp.experts") != -1 and "shared_experts" not in name: + if name.find("mlp.experts") != -1: n_experts = self.hparams["n_routed_experts"] assert bid is not None From e8671457449608eef29415db108c480dcd0cf38b Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 17:58:54 +1000 Subject: [PATCH 36/48] Apply suggestion from @CISC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- convert_hf_to_gguf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2cb47581e11f0..ef8d181a343eb 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6674,7 +6674,6 @@ def set_gguf_parameters(self): 
self.gguf_writer.add_expert_weights_norm(norm_topk_prob) _experts: list[dict[str, Tensor]] | None = None - _shared_experts: list[dict[str, Tensor]] | None = None def modify_tensors( self, data_torch: Tensor, name: str, bid: int | None From e75ec9940ddfdd4e0b1944e167aa8a8933b22452 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 18:08:03 +1000 Subject: [PATCH 37/48] feat: support GLM 4.5 family of models --- convert_hf_to_gguf.py | 6 ++---- convert_hf_to_gguf_update.py | 1 + 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ef8d181a343eb..a4ee706abc69d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -679,7 +679,7 @@ def get_vocab_base_pre(self, tokenizer) -> str: # ref: https://huggingface.co/THUDM/glm-4-9b-hf res = "glm4" if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902": - # ref: https://huggingface.co/zai-org/GLM-4.5-Air, https://huggingface.co/zai-org/GLM-4.5 + # ref: https://huggingface.co/zai-org/GLM-4.5-Air res = "glm4" if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 @@ -6622,8 +6622,6 @@ def set_vocab(self): special_vocab._set_special_token("sop", tokenizer.get_added_vocab()[""]) # 151333 if "" in tokenizer.get_added_vocab(): special_vocab._set_special_token("eop", tokenizer.get_added_vocab()[""]) # 151334 - if "[sMASK]" in tokenizer.get_added_vocab(): - special_vocab._set_special_token("smask", tokenizer.get_added_vocab()["[sMASK]"]) # 151332 # TODO: clean up once decided on an approach to think and /nothink # @@ -6762,7 +6760,7 @@ def set_vocab_chatglm3(self): vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab())) assert max(tokenizer.get_vocab().values()) < vocab_size role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"] - special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens + special_tokens = ["[MASK]", "[gMASK]", "sop", "eop"] + role_special_tokens for token_id in range(vocab_size): piece = tokenizer._convert_id_to_token(token_id) if token_id == 0: diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index abaf2ea9a1248..ea221fb1b5c4c 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -138,6 +138,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b"}, {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"}, {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"}, + {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"}, {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"}, {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"}, # falcon-h1 
series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes From 25bb67213a3d7762082ce8f055766d88ec2d8123 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 18:10:25 +1000 Subject: [PATCH 38/48] Update src/llama-model.cpp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- src/llama-model.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 9f4aa9c878820..63169d3ad2a69 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13564,13 +13564,16 @@ struct llm_build_glm4_moe : public llm_graph_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } + // Post-attention norm + cur = build_norm(cur, + model.layers[il].attn_post_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); - // FFN - hybrid dense/MoE layers - cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense) if (static_cast(il) < hparams.n_layer_dense_lead) { // Dense FFN layer From f197e752aa71a3a00afb1c54e46d1d51b9942fbb Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 18:10:36 +1000 Subject: [PATCH 39/48] Update src/llama-model.cpp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- src/llama-model.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 63169d3ad2a69..45ce37eeb59eb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4420,7 +4420,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_k_norm = create_tensor( tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE From d94d1fdb1e4983457efa6124c665746885e8e1c9 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 18:11:24 +1000 Subject: [PATCH 40/48] Update src/llama-model.cpp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- src/llama-model.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 45ce37eeb59eb..55c4592eb60a7 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4424,8 +4424,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE - const bool use_moe = - (hparams.n_expert > 0) && (static_cast(i) >= hparams.n_layer_dense_lead); + const bool use_moe = (static_cast(i) >= hparams.n_layer_dense_lead); if (use_moe) { // MoE layers From 6aa44f65949bf9e042e7e64cbb1a9f7a63f0753f Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 18:11:36 +1000 Subject: [PATCH 41/48] Update src/llama-model.cpp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 
Sigbjørn Skjæret --- src/llama-model.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 55c4592eb60a7..a07ddb326cd7b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13477,7 +13477,6 @@ struct llm_build_glm4 : public llm_graph_context { struct llm_build_glm4_moe : public llm_graph_context { llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); From 5902e4200a3fae0b866162fce33560c82e4a68d6 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 18:12:18 +1000 Subject: [PATCH 42/48] Update src/llama-model.cpp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- src/llama-model.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a07ddb326cd7b..f6f628ecd9505 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1421,14 +1421,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_GLM4_MOE: { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); // MoE parameters - ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, 0); - ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, 0); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, 0); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, 0); + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); From 49ff9e1b5fb32dd8590a827229dfe64dc0174d6c Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 18:12:02 +1000 Subject: [PATCH 43/48] feat: support GLM 4.5 family of models --- src/llama-model.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f6f628ecd9505..7abe00b67c8d5 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4430,7 +4430,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // MoE layers layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, i), { n_expert }, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, 0); if (n_expert == 0) { GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers"); From 1b6cf0e17003e37d4d64858371e66db6a6898590 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 18:14:21 +1000 Subject: [PATCH 44/48] feat: support GLM 4.5 family of models --- src/llama-model.cpp | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 7abe00b67c8d5..3ad1e0d670285 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4378,6 +4378,9 @@ bool 
llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_expert_used = hparams.n_expert_used; const int64_t n_expert_shared = hparams.n_expert_shared; + GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers"); + GGML_ASSERT(hparams.n_expert_used > 0 && "n_expert_used must be > 0 for GLM4_MOE MoE layers"); + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output @@ -4432,14 +4435,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, 0); - if (n_expert == 0) { - GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers"); - } - if (n_expert_used == 0) { - GGML_ASSERT(hparams.n_expert_used > 0 && - "n_expert_used must be > 0 for GLM4_MOE MoE layers"); - } - // MoE branch const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; From 3a4ac7e1457ec40f67e235884dc76a585c4d5fec Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 18:27:45 +1000 Subject: [PATCH 45/48] feat: support GLM 4.5 family of models - add rope neox --- src/llama-model.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 3ad1e0d670285..6d37c95810255 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -18074,7 +18074,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_PLM: case LLM_ARCH_CHATGLM: case LLM_ARCH_GLM4: - case LLM_ARCH_GLM4_MOE: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_GRANITE_HYBRID: @@ -18127,6 +18126,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_HUNYUAN_MOE: case LLM_ARCH_LFM2: case LLM_ARCH_SMALLTHINKER: + case LLM_ARCH_GLM4_MOE: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: From c56a5131b8c02d824ffe486bf6db352661c659da Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 19:14:02 +1000 Subject: [PATCH 46/48] feat: support GLM 4.5 family of models - aNoeda screenshot --- src/llama-model.cpp | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 6d37c95810255..2533439c06957 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13557,16 +13557,13 @@ struct llm_build_glm4_moe : public llm_graph_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - // Post-attention norm - cur = build_norm(cur, - model.layers[il].attn_post_norm, - NULL, - LLM_NORM_RMS, il); - cb(cur, "post_attn_norm", il); - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); + // Post-attention norm + cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense) if (static_cast(il) < hparams.n_layer_dense_lead) { // Dense FFN layer @@ -13582,9 +13579,6 @@ struct llm_build_glm4_moe : public llm_graph_context { const int64_t n_expert = hparams.n_expert; const int64_t n_expert_used = hparams.n_expert_used; - // Save original input for shared expert - ggml_tensor * residuals = cur; - // Process routed experts using existing MoE infrastructure ggml_tensor * routed_out = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, @@ -13600,7 +13594,7 @@ struct llm_build_glm4_moe : public llm_graph_context { 
cb(routed_out, "ffn_moe_out", il); // Process shared expert on original input - ggml_tensor * shared_out = build_ffn(residuals, + ggml_tensor * shared_out = build_ffn(cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, From 69f0ae5a592ee4ba553c867021fa3308a1f80ec6 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 21:57:36 +1000 Subject: [PATCH 47/48] glm 4.5 set eos/eog/eot token to <|user|> --- convert_hf_to_gguf.py | 32 ++++++-------------------------- 1 file changed, 6 insertions(+), 26 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a4ee706abc69d..847a7c4cb6c3e 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6603,19 +6603,12 @@ def set_vocab(self): self.gguf_writer.add_token_types(toktypes) # Special tokens - # BOS should be [gMASK] (151331), EOS should be <|endoftext|> (151329) as per tokenizer analysis - special_vocab._set_special_token( - "eos", tokenizer.get_added_vocab()["<|endoftext|>"] # 151329 - correct EOS token - ) - special_vocab._set_special_token( - "eot", tokenizer.get_added_vocab()["<|endoftext|>"] # 151329 - same as EOS - ) - special_vocab._set_special_token( - "unk", tokenizer.get_added_vocab()["<|endoftext|>"] - ) - special_vocab._set_special_token( - "bos", tokenizer.get_added_vocab()["[gMASK]"] # 151331 - ) + # Note: Using <|endoftext|> (151329) for eos and eot causes endless generation + special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331 + special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - end of + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - same as EOS + special_vocab._set_special_token("eog", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - same as EOS + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329 special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 if "" in tokenizer.get_added_vocab(): @@ -6623,19 +6616,6 @@ def set_vocab(self): if "" in tokenizer.get_added_vocab(): special_vocab._set_special_token("eop", tokenizer.get_added_vocab()[""]) # 151334 - # TODO: clean up once decided on an approach to think and /nothink - # - # Previously: - # if "/nothink" in tokenizer.get_added_vocab(): - # special_vocab._set_special_token("nothink", tokenizer.get_added_vocab()["/nothink"]) # 151360 - # Note: and are regular tokens (special=false in official config), not special tokens - # - # Latest thinking is: - # NOTE: /nothink token exists but causes generation issues as mentioned in - # https://huggingface.co/zai-org/GLM-4.5/discussions/9 - # "it is a very special token. Even as input, it will be encoded into a special token, causing generation issues." 
- # Therefore we do NOT add it to avoid generation problems - special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): From e101d48740af67961d280da8ab078ebeaedfa30d Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 3 Aug 2025 22:06:31 +1000 Subject: [PATCH 48/48] bump transformers and hfhub deps for glm 4.5 compat --- requirements/requirements-tool_bench.txt | 2 +- tools/server/tests/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/requirements-tool_bench.txt b/requirements/requirements-tool_bench.txt index b94521fc7fa72..edcbe8b839316 100644 --- a/requirements/requirements-tool_bench.txt +++ b/requirements/requirements-tool_bench.txt @@ -1,6 +1,6 @@ aiohttp~=3.9.3 pytest~=8.3.3 -huggingface_hub~=0.23.2 +huggingface_hub~=0.34.3 matplotlib~=3.10.0 numpy~=1.26.4 openai~=1.55.3 diff --git a/tools/server/tests/requirements.txt b/tools/server/tests/requirements.txt index 15d024914e841..43190c104781f 100644 --- a/tools/server/tests/requirements.txt +++ b/tools/server/tests/requirements.txt @@ -1,6 +1,6 @@ aiohttp~=3.9.3 pytest~=8.3.3 -huggingface_hub~=0.23.2 +huggingface_hub~=0.34.3 numpy~=1.26.4 openai~=1.55.3 prometheus-client~=0.20.0
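
The vocabulary patch above ends up mapping GLM-4.5's control tokens by ID out of the tokenizer's added vocabulary: [gMASK] (151331) as BOS, <|user|> (151336) as EOS/EOT/EOG (the final revision switches away from <|endoftext|> as EOS because it caused endless generation), <|endoftext|> (151329) as UNK, and <|observation|> (151338) as EOM. The snippet below is a minimal sanity-check sketch, not part of the patch series; it assumes transformers is installed and that zai-org/GLM-4.5-Air (or a local copy of it) is reachable, and only confirms that the token IDs cited in the diff comments match the tokenizer's added vocabulary.

# Sanity-check sketch: confirm the added-vocab IDs that Glm4MoeModel.set_vocab() relies on.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("zai-org/GLM-4.5-Air")
added = tokenizer.get_added_vocab()  # maps added token string -> token id

# Token IDs referenced in the conversion code's comments
expected = {
    "[gMASK]": 151331,          # BOS
    "<|endoftext|>": 151329,    # UNK (using it as EOS led to endless generation)
    "<|user|>": 151336,         # EOS / EOT / EOG in the final revision
    "<|observation|>": 151338,  # EOM
}
for token, token_id in expected.items():
    actual = added.get(token)
    status = "ok" if actual == token_id else "MISMATCH"
    print(f"{status}: {token!r} expected {token_id}, got {actual}")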