
Commit 740b88e

youth123, xingxingqiao, and Umpire2018 authored and Neo Zhang committed
llama : support glm3 and glm4 (ggml-org#8031)
* add chatglm3-6b model support
  huggingface model: https://hf-mirror.com/THUDM/chatglm3-6b
* remove .rotary_pos_emb.inv_freq and unused code for the chatglm3 model
* fix lint error
* optimize convert-hf-to-gguf.py for chatglm models
* support glm-4-9b-chat
* fix eos tokens for glm4
* remove unused log
* add preprocessing for chatglm3 and chatglm4
* add eos_id_list to llama.cpp
* fix code style
* fix conflicts
* Revert "add eos_id_list to llama.cpp" (reverts commit 3a4d579)
* set <|endoftext|> as eos and <|user|> as eot
* fix chat template bug
* add comment to glm prefix and suffix
* fix conflicts and add rope_ratio & ChatGLMForConditionalGeneration
* modify the general name of the glm model
* remove prefix and suffix
* use the normal glm4 chat template & use LLM_FFN_SWIGLU in phi3
* fix: resolve Flake8 errors in `convert-hf-to-gguf.py`
  - Fix E302 by adding two blank lines before top-level function definitions
  - Replace print statements to fix NP100
  - Fix E303 by ensuring only one blank line between lines of code
* fix rope ratio to solve incorrect answers
* fix by comments

---------

Signed-off-by: XingXing Qiao <[email protected]>
Co-authored-by: XingXing Qiao <[email protected]>
Co-authored-by: Umpire2018 <[email protected]>
1 parent 9b077de commit 740b88e
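One detail from the message worth surfacing: glm3 prompts are wrapped with a fixed prefix and suffix, which the converter records as a comment in `set_vocab_chatglm3()` below. A tiny illustration (`wrap_glm3` is our name, not part of the commit):

```python
# Illustration of the glm3 prompt layout noted in this commit
# ("[gMASK]sop<|user|>\n" + prompt + "<|assistant|>"); helper name is ours.
def wrap_glm3(prompt: str) -> str:
    return "[gMASK]sop<|user|>\n" + prompt + "<|assistant|>"

print(wrap_glm3("What is GGUF?"))
```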

File tree

6 files changed: +455 −25 lines changed

convert_hf_to_gguf.py

+187
@@ -487,6 +487,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a":
             # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code
             res = "jina-v2-code"
+        if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
+            # ref: https://huggingface.co/THUDM/glm-4-9b-chat
+            res = "chatglm-bpe"
         if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee":
             # ref: https://huggingface.co/LumiOpen/Viking-7B
             res = "viking"
@@ -3175,6 +3178,190 @@ def write_tensors(self):
             self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
 
 
+@Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration")
+class ChatGLMModel(Model):
+    model_arch = gguf.MODEL_ARCH.CHATGLM
+
+    def set_vocab_chatglm3(self):
+        dir_model = self.dir_model
+        hparams = self.hparams
+        tokens: list[bytearray] = []
+        toktypes: list[int] = []
+        scores: list[float] = []
+
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
+        vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab()))
+        assert max(tokenizer.get_vocab().values()) < vocab_size
+        role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
+        special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
+        for token_id in range(vocab_size):
+            piece = tokenizer._convert_id_to_token(token_id)
+            if token_id == 0:
+                piece = "<unk>"
+            elif token_id == 1:
+                piece = "<bos>"
+            elif token_id == 2:
+                piece = "<eos>"
+
+            text = piece.encode("utf-8")
+            score = 0.0
+            # Referencing the tokenizer Python implementation (https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py),
+            # a score is only valid if the token id is less than tokenizer.tokenizer.sp_model.vocab_size()
+            if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size():
+                score = tokenizer.tokenizer.sp_model.get_score(token_id)
+
+            if len(piece) == 0:
+                text = f"[PAD{token_id}]".encode("utf-8")
+
+            if token_id >= tokenizer.tokenizer.sp_model.vocab_size():
+                if piece in special_tokens:
+                    # show special tokens in prompt
+                    toktype = SentencePieceTokenTypes.USER_DEFINED
+                else:
+                    toktype = SentencePieceTokenTypes.UNKNOWN
+                tokens.append(text)
+                scores.append(score)
+                toktypes.append(toktype)
+                continue
+
+            toktype = SentencePieceTokenTypes.NORMAL
+            if tokenizer.tokenizer.sp_model.is_unknown(token_id):
+                toktype = SentencePieceTokenTypes.UNKNOWN
+            elif tokenizer.tokenizer.sp_model.is_control(token_id):
+                toktype = SentencePieceTokenTypes.CONTROL
+            elif tokenizer.tokenizer.sp_model.is_unused(token_id):
+                toktype = SentencePieceTokenTypes.UNUSED
+            elif tokenizer.tokenizer.sp_model.is_byte(token_id):
+                toktype = SentencePieceTokenTypes.BYTE
+
+            tokens.append(text)
+            scores.append(score)
+            toktypes.append(toktype)
+
+        self.gguf_writer.add_tokenizer_model("llama")
+        # glm3 needs prefix and suffix formatted as:
+        # prompt = "[gMASK]sop<|user|>\n" + prompt + "<|assistant|>"
+        self.gguf_writer.add_tokenizer_pre("chatglm-spm")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    @staticmethod
+    def token_bytes_to_string(b):
+        from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
+        byte_encoder = bytes_to_unicode()
+        return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
+
+    @staticmethod
+    def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
+        parts = [bytes([b]) for b in token]
+        while True:
+            min_idx = None
+            min_rank = None
+            for i, pair in enumerate(zip(parts[:-1], parts[1:])):
+                rank = mergeable_ranks.get(pair[0] + pair[1])
+                if rank is not None and (min_rank is None or rank < min_rank):
+                    min_idx = i
+                    min_rank = rank
+            if min_rank is None or (max_rank is not None and min_rank >= max_rank):
+                break
+            assert min_idx is not None
+            parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
+        return parts
+
+    def set_vocab(self):
+        if "THUDM/chatglm3-6b" in self.hparams.get("_name_or_path", ""):
+            self.set_vocab_chatglm3()
+            return
+
+        dir_model = self.dir_model
+        hparams = self.hparams
+        tokens: list[str] = []
+        toktypes: list[int] = []
+
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
+        vocab_size = hparams["padded_vocab_size"]
+        assert max(tokenizer.get_vocab().values()) < vocab_size
+
+        tokpre = self.get_vocab_base_pre(tokenizer)
+
+        merges = []
+        vocab = {}
+        mergeable_ranks = tokenizer.mergeable_ranks
+        for token, rank in mergeable_ranks.items():
+            vocab[ChatGLMModel.token_bytes_to_string(token)] = rank
+            if len(token) == 1:
+                continue
+            merged = ChatGLMModel.bpe(mergeable_ranks, token, max_rank=rank)
+            assert len(merged) >= 2 and len(merged) <= 7
+            merges.append(' '.join(map(ChatGLMModel.token_bytes_to_string, merged)))
+
+        # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
+        added_vocab = tokenizer.get_added_vocab()
+        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()}
+
+        for i in range(vocab_size):
+            if i not in reverse_vocab:
+                tokens.append(f"[PAD{i}]")
+                toktypes.append(gguf.TokenType.USER_DEFINED)
+            elif reverse_vocab[i] in added_vocab:
+                tokens.append(reverse_vocab[i])
+                if tokenizer.added_tokens_decoder[i].special:
+                    toktypes.append(gguf.TokenType.CONTROL)
+                else:
+                    toktypes.append(gguf.TokenType.USER_DEFINED)
+            else:
+                tokens.append(reverse_vocab[i])
+                toktypes.append(gguf.TokenType.NORMAL)
+
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
+        special_vocab.merges = merges
+        # only add special tokens when they were not already loaded from config.json
+        special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
+        # this one is usually not in config.json anyway
+        special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        self.gguf_writer.add_name(self.hparams.get("_name_or_path").split("/")[1])  # THUDM/glm4-9b-chat or THUDM/chatglm3-6b
+        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
+        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
+        n_head_kv = self.hparams.get("multi_query_group_num", n_head)
+        self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
+        self.gguf_writer.add_embedding_length(n_embed)
+        self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", 4 * n_embed))
+        self.gguf_writer.add_block_count(self.hparams["num_layers"])
+        self.gguf_writer.add_head_count(n_head)
+        self.gguf_writer.add_head_count_kv(n_head_kv)
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layernorm_epsilon"])
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_rope_dimension_count(64)
+        self.gguf_writer.add_add_bos_token(False)
+        rope_freq = 10000
+        if "rope_ratio" in self.hparams:
+            rope_freq = rope_freq * self.hparams["rope_ratio"]
+        self.gguf_writer.add_rope_freq_base(rope_freq)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if name.endswith(".rotary_pos_emb.inv_freq"):
+            return []
+
+        name = name.removeprefix("transformer.")
+        return [(self.map_tensor_name(name), data_torch)]
+
 
 ###### CONVERSION LOGIC ######
 
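One piece worth calling out: `bpe()` recovers an explicit merge list from a tiktoken-style rank table, which GLM4's tokenizer ships instead of a merges file. For each multi-byte token it replays merges using only ranks strictly below the token's own rank; the surviving parts are that token's merge rule. A self-contained toy run (the rank table is invented for illustration):

```python
# Toy demonstration of the merge-reconstruction trick in ChatGLMModel.bpe
# (toy ranks; real tokenizers ship tens of thousands of entries).
def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
    parts = [bytes([b]) for b in token]
    while True:
        min_idx = None
        min_rank = None
        # find the lowest-ranked adjacent pair that can still be merged
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank
        # stop before applying any merge at or above the token's own rank
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
    return parts

ranks = {b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"abc": 4}
print(bpe(ranks, b"abc", max_rank=4))  # [b'ab', b'c'] -> merge rule "ab c"
```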
gguf-py/gguf/constants.py

+17 −1
@@ -120,7 +120,6 @@ class Tokenizer:
     MIDDLE_ID = "tokenizer.ggml.middle_token_id"
     EOT_ID = "tokenizer.ggml.eot_token_id"
 
-
 #
 # recommended mapping of model tensor names for storage in gguf
 #
@@ -163,6 +162,7 @@ class MODEL_ARCH(IntEnum):
     OPENELM = auto()
     ARCTIC = auto()
     DEEPSEEK2 = auto()
+    CHATGLM = auto()
     BITNET = auto()
     T5 = auto()
     JAIS = auto()
@@ -289,6 +289,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.OPENELM: "openelm",
     MODEL_ARCH.ARCTIC: "arctic",
     MODEL_ARCH.DEEPSEEK2: "deepseek2",
+    MODEL_ARCH.CHATGLM: "chatglm",
     MODEL_ARCH.BITNET: "bitnet",
     MODEL_ARCH.T5: "t5",
     MODEL_ARCH.JAIS: "jais",
@@ -924,6 +925,18 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN_SHEXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
     ],
+    MODEL_ARCH.CHATGLM : [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.BITNET: [
         MODEL_TENSOR.ATTN_Q,
         MODEL_TENSOR.ATTN_K,
@@ -1020,6 +1033,9 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
     ],
+    MODEL_ARCH.CHATGLM: [
+        MODEL_TENSOR.ROPE_FREQS,
+    ],
 }
 
 #
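The three additions above (enum member, name string, tensor list) are all that gguf-py needs to know about an architecture. A quick sketch of what they expose, assuming the gguf-py package from this tree is importable:

```python
# Sketch: enumerate the GGUF tensor names implied by the new CHATGLM entries.
from gguf.constants import MODEL_ARCH, MODEL_ARCH_NAMES, MODEL_TENSORS, TENSOR_NAMES

print(MODEL_ARCH_NAMES[MODEL_ARCH.CHATGLM])  # "chatglm"
for tensor in MODEL_TENSORS[MODEL_ARCH.CHATGLM]:
    print(TENSOR_NAMES[tensor])              # e.g. "blk.{bid}.attn_qkv"
```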

gguf-py/gguf/tensor_mapping.py

+12 −2
@@ -24,6 +24,7 @@ class TensorNameMap:
             "backbone.embedding", # mamba
             "backbone.embeddings", # mamba-hf
             "transformer.in_out_embed", # Grok
+            "embedding.word_embeddings", # chatglm
             "transformer.token_embeddings", # openelm
             "shared", # t5
         ),
@@ -55,6 +56,7 @@ class TensorNameMap:
             "output", # llama-pth bloom internlm2
             "word_embeddings_for_head", # persimmon
             "lm_head.linear", # phi2
+            "output_layer", # chatglm
         ),
 
         # Output norm
@@ -71,12 +73,14 @@ class TensorNameMap:
             "model.norm_f", # mamba-qbert
             "backbone.norm_f", # mamba
             "transformer.rms_norm", # Grok
+            "encoder.final_layernorm", # chatglm
             "transformer.norm", # openelm
         ),
 
         # Rope frequencies
         MODEL_TENSOR.ROPE_FREQS: (
             "rope.freqs", # llama-pth
+            "rotary_pos_emb.inv_freq", # chatglm
         ),
     }
 
@@ -101,6 +105,7 @@ class TensorNameMap:
             "backbone.layers.{bid}.norm", # mamba
             "transformer.decoder_layer.{bid}.rms_norm", # Grok
             "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
+            "encoder.layers.{bid}.input_layernorm", # chatglm
             "transformer.layers.{bid}.attn_norm", # openelm
         ),
 
@@ -124,6 +129,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mixer.Wqkv", # phi2
             "encoder.layers.{bid}.attn.Wqkv", # nomic-bert
             "model.layers.{bid}.self_attn.qkv_proj", # phi3
+            "encoder.layers.{bid}.self_attention.query_key_value", # chatglm
             "transformer.layers.{bid}.attn.qkv_proj", # openelm
         ),
 
@@ -135,7 +141,7 @@ class TensorNameMap:
             "transformer.h.{bid}.attn.q_proj", # gpt-j
             "model.layers.layers.{bid}.self_attn.q_proj", # plamo
             "model.layers.{bid}.attention.wq", # internlm2
-            "transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok
+            "transformer.decoder_layer.{bid}.multi_head_attention.query", # Grok
         ),
 
         # Attention key
@@ -147,7 +153,7 @@ class TensorNameMap:
             "transformer.h.{bid}.attn.k", # refact
             "model.layers.layers.{bid}.self_attn.k_proj", # plamo
             "model.layers.{bid}.attention.wk", # internlm2
-            "transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok
+            "transformer.decoder_layer.{bid}.multi_head_attention.key", # Grok
         ),
 
         # Attention value
@@ -182,6 +188,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.attn.out_proj", # nomic-bert
             "transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
             "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
+            "encoder.layers.{bid}.self_attention.dense", # chatglm
             "transformer.layers.{bid}.attn.out_proj", # openelm
         ),
 
@@ -218,6 +225,7 @@ class TensorNameMap:
             "h.{bid}.ln_2", # gpt2
             "model.layers.{bid}.ffn_norm", # internlm2
             "transformer.decoder_layer.{bid}.rms_norm_2", # Grok
+            "encoder.layers.{bid}.post_attention_layernorm", # chatglm
             "transformer.layers.{bid}.ffn_norm", # openelm
         ),
 
@@ -268,6 +276,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.c_fc", # starcoder2
             "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
             "model.layers.{bid}.residual_mlp.w3", # arctic
+            "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -337,6 +346,7 @@ class TensorNameMap:
             "transformer.layers.{bid}.ffn.proj_2", # openelm
             "model.layers.{bid}.residual_mlp.w2", # arctic
             "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
+            "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
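These mappings can be spot-checked through gguf-py's name map, the same lookup `Model.map_tensor_name()` uses during conversion. A small check (assuming gguf-py is importable; 40 is glm-4-9b-chat's layer count, but any positive block count works):

```python
# Sketch: confirm a chatglm source tensor name resolves via the new mappings.
import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CHATGLM, 40)
print(tmap.get_name("encoder.layers.0.self_attention.query_key_value.weight",
                    try_suffixes=(".weight", ".bias")))
# expected: "blk.0.attn_qkv.weight"
```

Note that `ChatGLMModel.modify_tensors()` strips the leading "transformer." before mapping, which is why the entries above start at "encoder.".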

include/llama.h

+4 −2
@@ -88,8 +88,10 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_DBRX       = 13,
         LLAMA_VOCAB_PRE_TYPE_SMAUG      = 14,
         LLAMA_VOCAB_PRE_TYPE_PORO       = 15,
-        LLAMA_VOCAB_PRE_TYPE_VIKING     = 16,
-        LLAMA_VOCAB_PRE_TYPE_JAIS       = 17,
+        LLAMA_VOCAB_PRE_TYPE_CHATGLM3   = 16,
+        LLAMA_VOCAB_PRE_TYPE_CHATGLM4   = 17,
+        LLAMA_VOCAB_PRE_TYPE_VIKING     = 18,
+        LLAMA_VOCAB_PRE_TYPE_JAIS       = 19,
     };
 
     // note: these values should be synchronized with ggml_rope
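Inserting CHATGLM3/CHATGLM4 at 16 and 17 renumbers VIKING and JAIS, which should be harmless: GGUF files persist the pre-tokenizer as a string ("chatglm-bpe" / "chatglm-spm", written by the converter above), and llama.cpp resolves that string to this enum at load time, so the numeric values never leave the process. A sketch for reading the stored string back (field access follows gguf-py's GGUFReader layout as we understand it; the path is a placeholder):

```python
# Sketch: read the persisted pre-tokenizer string from a converted model.
# Assumption: gguf-py's GGUFReader exposes fields/parts/data as used here.
from gguf import GGUFReader

reader = GGUFReader("glm-4-9b-chat.gguf")  # placeholder path
field = reader.fields["tokenizer.ggml.pre"]
print(bytes(field.parts[field.data[0]]).decode("utf-8"))  # e.g. "chatglm-bpe"
```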
