51 commits
c755038  feat: support GLM 4.5 family of models (sammcj, Jul 29, 2025)
0edf732  feat: support GLM 4.5 family of models (sammcj, Jul 30, 2025)
6b478bb  feat: support GLM 4.5 family of models (sammcj, Jul 30, 2025)
9652812  feat: support GLM 4.5 family of models (sammcj, Jul 30, 2025)
07bb0dd  feat: support GLM 4.5 family of models (sammcj, Jul 30, 2025)
fae4df8  feat: support GLM 4.5 family of models (sammcj, Jul 30, 2025)
03fad04  feat: support GLM 4.5 family of models (sammcj, Jul 30, 2025)
b61fc91  feat: support GLM 4.5 family of models (sammcj, Jul 31, 2025)
999c07a  feat: support GLM 4.5 family of models (sammcj, Jul 31, 2025)
5baa607  feat: support GLM 4.5 family of models (sammcj, Jul 31, 2025)
62447f8  Update convert_hf_to_gguf.py (sammcj, Jul 31, 2025)
58898b5  feat: support GLM 4.5 family of models (sammcj, Jul 31, 2025)
ab3183e  feat: support GLM 4.5 family of models (sammcj, Jul 31, 2025)
6f3d94e  feat: support GLM 4.5 family of models (sammcj, Jul 31, 2025)
b25f462  feat: support GLM 4.5 family of models (sammcj, Jul 31, 2025)
c90f63a  feat: support GLM 4.5 family of models (sammcj, Aug 1, 2025)
bdfe09c  feat: support GLM 4.5 family of models (sammcj, Aug 1, 2025)
3d15c4a  feat: support GLM 4.5 family of models (sammcj, Aug 1, 2025)
dbfadb6  feat: support GLM 4.5 family of models (sammcj, Aug 2, 2025)
c3eb159  Merge branch 'master' into glm-4-5 (createthis, Aug 3, 2025)
133c583  Apply suggestion from @CISC (sammcj, Aug 3, 2025)
15b79e8  Apply suggestion from @CISC (sammcj, Aug 3, 2025)
4e8cf30  Apply suggestion from @CISC (sammcj, Aug 3, 2025)
a5434b8  Apply suggestion from @CISC (sammcj, Aug 3, 2025)
0d272cc  Apply suggestion from @CISC (sammcj, Aug 3, 2025)
0da017c  Apply suggestion from @CISC (sammcj, Aug 3, 2025)
90871fb  Apply suggestion from @CISC (sammcj, Aug 3, 2025)
fb1d48d  Apply suggestion from @CISC (sammcj, Aug 3, 2025)
7f23adf  Apply suggestion from @CISC (sammcj, Aug 3, 2025)
9b0b1b4  Apply suggestion from @CISC (sammcj, Aug 3, 2025)
6ffa4a3  Apply suggestion from @CISC (sammcj, Aug 3, 2025)
166c002  Apply suggestion from @CISC (sammcj, Aug 3, 2025)
e342dfd  Apply suggestion from @CISC (sammcj, Aug 3, 2025)
4c24d5b  Apply suggestion from @CISC (sammcj, Aug 3, 2025)
b5ed4a8  Apply suggestion from @CISC (sammcj, Aug 3, 2025)
f36c3b7  Apply suggestion from @CISC (sammcj, Aug 3, 2025)
e867145  Apply suggestion from @CISC (sammcj, Aug 3, 2025)
e75ec99  feat: support GLM 4.5 family of models (sammcj, Aug 3, 2025)
25bb672  Update src/llama-model.cpp (sammcj, Aug 3, 2025)
f197e75  Update src/llama-model.cpp (sammcj, Aug 3, 2025)
d94d1fd  Update src/llama-model.cpp (sammcj, Aug 3, 2025)
6aa44f6  Update src/llama-model.cpp (sammcj, Aug 3, 2025)
5902e42  Update src/llama-model.cpp (sammcj, Aug 3, 2025)
49ff9e1  feat: support GLM 4.5 family of models (sammcj, Aug 3, 2025)
1b6cf0e  feat: support GLM 4.5 family of models (sammcj, Aug 3, 2025)
3a4ac7e  feat: support GLM 4.5 family of models - add rope neox (sammcj, Aug 3, 2025)
c56a513  feat: support GLM 4.5 family of models - aNoeda screenshot (sammcj, Aug 3, 2025)
69f0ae5  glm 4.5 set eos/eog/eot token to <|user|> (sammcj, Aug 3, 2025)
919adf6  Merge branch 'glm-4-5' of github.com:createthis/llama.cpp into glm-4-5 (createthis, Aug 3, 2025)
e101d48  bump transformers and hfhub deps for glm 4.5 compat (sammcj, Aug 3, 2025)
183f568  Merge remote-tracking branch 'sammcj/glm-4-5' into glm-4-5 (createthis, Aug 3, 2025)
148 changes: 147 additions & 1 deletion convert_hf_to_gguf.py
@@ -678,6 +678,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
# ref: https://huggingface.co/THUDM/glm-4-9b-hf
res = "glm4"
if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902":
# ref: https://huggingface.co/zai-org/GLM-4.5-Air
res = "glm4"
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
res = "minerva-7b"
@@ -6685,6 +6688,149 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Glm4MoeForCausalLM")
class Glm4MoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.GLM4_MOE

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# GLM4_MOE has num_hidden_layers + 1 actual layers (including NextN layer)
self.block_count = self.hparams["num_hidden_layers"] + 1
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

def set_vocab(self):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)

# Special tokens
# Note: Using <|endoftext|> (151329) for eos and eot causes endless generation
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - end of
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - same as EOS
special_vocab._set_special_token("eog", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - same as EOS
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338

if "<sop>" in tokenizer.get_added_vocab():
special_vocab._set_special_token("sop", tokenizer.get_added_vocab()["<sop>"]) # 151333
if "<eop>" in tokenizer.get_added_vocab():
special_vocab._set_special_token("eop", tokenizer.get_added_vocab()["<eop>"]) # 151334

special_vocab.add_to_gguf(self.gguf_writer)

def set_gguf_parameters(self):
super().set_gguf_parameters()
if (rope_dim := self.hparams.get("head_dim")) is None:
rope_dim = (
self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
)
self.gguf_writer.add_rope_dimension_count(
int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))
)

# MoE parameters - Use only routed expert count (shared experts handled separately)
if (n_routed_experts := self.hparams.get("n_routed_experts")) is not None:
self.gguf_writer.add_expert_count(n_routed_experts)
if (num_experts_per_tok := self.hparams.get("num_experts_per_tok")) is not None:
self.gguf_writer.add_expert_used_count(num_experts_per_tok)
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None:
self.gguf_writer.add_expert_shared_count(n_shared_experts)
if (first_k_dense_replace := self.hparams.get("first_k_dense_replace")) is not None:
self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)

# Expert gating function (sigmoid for GLM4_MOE)
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)

# Routed scaling factor
if (routed_scaling_factor := self.hparams.get("routed_scaling_factor")) is not None:
self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)

# Normalise topk probabilities
if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None:
self.gguf_writer.add_expert_weights_norm(norm_topk_prob)

_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(
self, data_torch: Tensor, name: str, bid: int | None
) -> Iterable[tuple[str, Tensor]]:
if name.startswith("model.visual."): # ignore visual part
return []
elif name.startswith("model.language_model."):
name = name.replace("language_model.", "") # for multimodal variants

# Handle main token embedding (but not layer-specific NextN embeddings)
if name == "model.embed_tokens.weight" and ".layers." not in name:
return [(self.map_tensor_name("token_embd.weight"), data_torch)]

# Handle routed experts
if name.find("mlp.experts") != -1:
n_experts = self.hparams["n_routed_experts"]
assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []

# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"

new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))
return tensors
else:
return []

if name.endswith("e_score_correction_bias"):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")

# Handle special NextN tensors - preserve for future MTP support
if (
".embed_tokens." in name
or ".shared_head." in name
or ".eh_proj." in name
or ".enorm." in name
or ".hnorm." in name
):
new_name = name.replace("model.layers.", "blk.").replace("model.", "").replace(".weight", "")
return [(new_name, data_torch)]

new_name = self.map_tensor_name(name)

return [(new_name, data_torch)]

def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
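
To make the expert-merging branch above concrete, here is a minimal sketch, assuming three routed experts with toy 2x4 down_proj weights; torch.stack collapses them into the single 3D tensor that is written out as ffn_down_exps:

```python
import torch

# Toy stand-in for self._experts[bid]: per-expert weights collected from
# "model.layers.{bid}.mlp.experts.{xid}.down_proj.weight" (shapes are illustrative).
per_expert = [torch.randn(2, 4) for _ in range(3)]  # 3 routed experts

# Stacking along dim 0 yields one (n_expert, rows, cols) tensor, the layout
# the converter emits as the merged blk.{bid}.ffn_down_exps weight.
merged = torch.stack(per_expert, dim=0)
print(merged.shape)  # torch.Size([3, 2, 4])
```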


@ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
class ChatGLMModel(TextModel):
model_arch = gguf.MODEL_ARCH.CHATGLM
@@ -6701,7 +6847,7 @@ def set_vocab_chatglm3(self):
vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab()))
assert max(tokenizer.get_vocab().values()) < vocab_size
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
special_tokens = ["[MASK]", "[gMASK]", "sop", "eop"] + role_special_tokens
for token_id in range(vocab_size):
piece = tokenizer._convert_id_to_token(token_id)
if token_id == 0:
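
Before moving on to the tokenizer update script, a quick check of the NextN branch in Glm4MoeModel.modify_tensors above, applied to a hypothetical layer-46 tensor name (the name is illustrative, not read from a checkpoint):

```python
# Hypothetical HF name for a NextN/MTP tensor in the extra final layer.
name = "model.layers.46.eh_proj.weight"

# Same replace chain as the converter's NextN branch.
new_name = name.replace("model.layers.", "blk.").replace("model.", "").replace(".weight", "")
print(new_name)  # blk.46.eh_proj
```
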
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
@@ -138,6 +138,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b"},
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"},
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
{"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
{"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"},
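
For context, the chkhsh values registered above fingerprint a tokenizer's behaviour rather than its files. A minimal sketch, assuming the hash is a SHA-256 over the token IDs produced for a fixed probe string; the probe text and helper name below are placeholders, not the script's actual code:

```python
from hashlib import sha256
from transformers import AutoTokenizer

def tokenizer_checksum(repo_or_dir: str, probe_text: str) -> str:
    # Hash the token IDs the tokenizer produces for a fixed probe string, so two
    # models with identical pre-tokenization collapse onto the same "res" entry.
    tokenizer = AutoTokenizer.from_pretrained(repo_or_dir)
    token_ids = tokenizer.encode(probe_text)
    return sha256(str(token_ids).encode()).hexdigest()

# With the real probe text from convert_hf_to_gguf_update.py, the GLM-4.5-Air
# repo above would be expected to yield "9ca2dd61...".
```
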
46 changes: 46 additions & 0 deletions gguf-py/gguf/constants.py
@@ -357,6 +357,7 @@ class MODEL_ARCH(IntEnum):
DEEPSEEK2 = auto()
CHATGLM = auto()
GLM4 = auto()
GLM4_MOE = auto()
BITNET = auto()
T5 = auto()
T5ENCODER = auto()
@@ -614,6 +615,12 @@ class MODEL_TENSOR(IntEnum):
A_MMPROJ_FC = auto()
A_MM_NORM_PRE = auto()
A_MM_NORM_MID = auto()
NEXTN_EH_PROJ = auto() # nextn tensors (glm4moe)
NEXTN_EMBED_TOKENS = auto() # nextn tensors (glm4moe)
NEXTN_ENORM = auto() # nextn tensors (glm4moe)
NEXTN_HNORM = auto() # nextn tensors (glm4moe)
NEXTN_SHARED_HEAD_HEAD = auto() # nextn tensors (glm4moe)
NEXTN_SHARED_HEAD_NORM = auto() # nextn tensors (glm4moe)


MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -678,6 +685,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.DEEPSEEK2: "deepseek2",
MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.GLM4: "glm4",
MODEL_ARCH.GLM4_MOE: "glm4moe",
MODEL_ARCH.BITNET: "bitnet",
MODEL_ARCH.T5: "t5",
MODEL_ARCH.T5ENCODER: "t5encoder",
@@ -936,6 +944,13 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc",
MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre",
MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid",
# NextN/MTP tensors (GLM4_MOE)
MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.eh_proj",
MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.embed_tokens",
MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.enorm",
MODEL_TENSOR.NEXTN_HNORM: "blk.{bid}.hnorm",
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.shared_head.head",
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.shared_head.norm",
}

MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -2124,6 +2139,37 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_POST_NORM,
],
MODEL_ARCH.GLM4_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.FFN_GATE, # dense layers
MODEL_TENSOR.FFN_DOWN, # dense layers
MODEL_TENSOR.FFN_UP, # dense layers
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
# NextN/MTP tensors - preserved but unused
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.BITNET: [
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
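
A small sketch of how the "{bid}" placeholders in the NextN entries above become concrete per-layer names; the layer index follows the GLM-4.5-Air comments elsewhere in this PR, and the formatting call is only illustrative of what gguf-py's tensor-name mapping does:

```python
nextn_templates = [
    "blk.{bid}.eh_proj",
    "blk.{bid}.embed_tokens",
    "blk.{bid}.enorm",
    "blk.{bid}.hnorm",
    "blk.{bid}.shared_head.head",
    "blk.{bid}.shared_head.norm",
]

last_block = 46  # NextN layer index for GLM-4.5-Air (illustrative)
print([t.format(bid=last_block) for t in nextn_templates])
# ['blk.46.eh_proj', 'blk.46.embed_tokens', 'blk.46.enorm', 'blk.46.hnorm',
#  'blk.46.shared_head.head', 'blk.46.shared_head.norm']
```
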
3 changes: 2 additions & 1 deletion models/templates/README.md
@@ -21,4 +21,5 @@ These templates can be updated with the following commands:
./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
./scripts/get_chat_template.py Qwen/QwQ-32B > models/templates/Qwen-QwQ-32B.jinja
./scripts/get_chat_template.py Qwen/Qwen3-0.6B > models/templates/Qwen-Qwen3-0.6B.jinja
```
./scripts/get_chat_template.py zai-org/GLM-4.5 > models/templates/zai-org-GLM-4.5.jinja
```
2 changes: 1 addition & 1 deletion requirements/requirements-convert_legacy_llama.txt
@@ -1,5 +1,5 @@
numpy~=1.26.4
sentencepiece~=0.2.0
transformers>=4.45.1,<5.0.0
transformers>=4.54.1,<5.0.0
gguf>=0.1.0
protobuf>=4.21.0,<5.0.0
2 changes: 1 addition & 1 deletion requirements/requirements-tool_bench.txt
@@ -1,6 +1,6 @@
aiohttp~=3.9.3
pytest~=8.3.3
huggingface_hub~=0.23.2
huggingface_hub~=0.34.3
matplotlib~=3.10.0
numpy~=1.26.4
openai~=1.55.3
43 changes: 43 additions & 0 deletions src/llama-arch.cpp
@@ -62,6 +62,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
{ LLM_ARCH_CHATGLM, "chatglm" },
{ LLM_ARCH_GLM4, "glm4" },
{ LLM_ARCH_GLM4_MOE, "glm4moe" },
{ LLM_ARCH_BITNET, "bitnet" },
{ LLM_ARCH_T5, "t5" },
{ LLM_ARCH_T5ENCODER, "t5encoder" },
@@ -1391,6 +1392,40 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_GLM4_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, // dense layers
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, // dense layers
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, // dense layers
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
// NextN/MTP tensors - preserved but unused (in final layer, dynamic layer number)
{ LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.eh_proj" },
{ LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.embed_tokens" },
{ LLM_TENSOR_NEXTN_ENORM, "blk.%d.enorm" },
{ LLM_TENSOR_NEXTN_HNORM, "blk.%d.hnorm" },
{ LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.shared_head.head" },
{ LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.shared_head.norm" },
},
},
{
LLM_ARCH_BITNET,
{
@@ -2181,6 +2216,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
{LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are loaded but never used (reserved for future MTP support)
// These tensors only exist in the last layer (layer 46 for GLM-4.5-Air) and are treated as output tensors
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}},
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}},
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}},
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}},
};

LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
7 changes: 7 additions & 0 deletions src/llama-arch.h
@@ -66,6 +66,7 @@ enum llm_arch {
LLM_ARCH_DEEPSEEK2,
LLM_ARCH_CHATGLM,
LLM_ARCH_GLM4,
LLM_ARCH_GLM4_MOE,
LLM_ARCH_BITNET,
LLM_ARCH_T5,
LLM_ARCH_T5ENCODER,
@@ -409,6 +410,12 @@ enum llm_tensor {
LLM_TENSOR_SHORTCONV_CONV,
LLM_TENSOR_SHORTCONV_INPROJ,
LLM_TENSOR_SHORTCONV_OUTPROJ,
LLM_TENSOR_NEXTN_EH_PROJ,
LLM_TENSOR_NEXTN_EMBED_TOKENS,
LLM_TENSOR_NEXTN_ENORM,
LLM_TENSOR_NEXTN_HNORM,
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
};

enum llm_tensor_layer {
8 changes: 4 additions & 4 deletions src/llama-graph.cpp
@@ -749,8 +749,8 @@ ggml_tensor * llm_graph_context::build_ffn(

if (down) {
cur = build_lora_mm(down, cur);
if (arch == LLM_ARCH_GLM4) {
// GLM4 seems to have numerical issues with half-precision accumulators
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
@@ -1391,8 +1391,8 @@ ggml_tensor * llm_graph_context::build_attn(

if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4) {
// GLM4 seems to have numerical issues with half-precision accumulators
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
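
The two hunks above extend the existing GLM4 workaround to GLM4_MOE: the attention-output and FFN-down matrix multiplies are forced to accumulate in F32. As a rough, self-contained illustration (NumPy here, not ggml) of why accumulator precision matters for long dot products:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8192  # on the order of a large hidden/FFN dimension
a = rng.standard_normal(n).astype(np.float16)
b = rng.standard_normal(n).astype(np.float16)

ref = np.dot(a.astype(np.float64), b.astype(np.float64))  # high-precision reference

acc16 = np.float16(0)   # fp16 accumulator (the problematic case)
acc32 = np.float32(0)   # fp32 accumulator (what GGML_PREC_F32 requests)
for p in a * b:         # products are computed in fp16 either way
    acc16 = np.float16(acc16 + p)
    acc32 = acc32 + np.float32(p)

print(abs(float(acc16) - ref))  # typically much larger than...
print(abs(float(acc32) - ref))  # ...the error with an fp32 accumulator
```
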
4 changes: 4 additions & 0 deletions src/llama-kv-cache-unified.cpp
@@ -39,6 +39,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
if (model.arch == LLM_ARCH_GEMMA3N) {
n_layer_cache = 20;
}
if (model.arch == LLM_ARCH_GLM4_MOE) {
// GLM4_MOE: Only process first 46 transformer layers, skip NextN layer
n_layer_cache = hparams.n_layer - 1;
}

// create a context for each buffer type
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
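
Tying this back to the converter: a quick sanity check of the layer counts, assuming GLM-4.5-Air reports 46 hidden layers in its config (the value is taken from the comments in this PR, not read from the config here):

```python
num_hidden_layers = 46               # hparams["num_hidden_layers"] (illustrative)
block_count = num_hidden_layers + 1  # converter adds the NextN/MTP layer -> 47 GGUF blocks
n_layer_cache = block_count - 1      # KV cache skips the NextN layer -> 46 cached layers
print(block_count, n_layer_cache)    # 47 46
```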