
Commit 37c746d

simonJJJ and ggerganov authored
llama : add Qwen support (ggml-org#4281)
* enable qwen to llama.cpp
* llama : do not GPU split bias tensors

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 880f579 commit 37c746d

File tree: 5 files changed, +372 -9 lines changed

convert-hf-to-gguf.py (+130 -1)

@@ -10,7 +10,7 @@
 import sys
 from enum import IntEnum
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast
+from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast, Optional
 
 import numpy as np
 import torch
@@ -168,6 +168,8 @@ def from_model_architecture(model_architecture):
            return PersimmonModel
        if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
            return StableLMModel
+       if model_architecture == "QWenLMHeadModel":
+           return QwenModel
        return Model
 
    def _is_model_safetensors(self) -> bool:
@@ -203,6 +205,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
            return gguf.MODEL_ARCH.PERSIMMON
        if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
            return gguf.MODEL_ARCH.STABLELM
+       if arch == "QWenLMHeadModel":
+           return gguf.MODEL_ARCH.QWEN
 
        raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
@@ -832,6 +836,131 @@ def set_gguf_parameters(self):
        self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
        self.gguf_writer.add_layer_norm_eps(1e-5)
 
+
+class QwenModel(Model):
+    @staticmethod
+    def token_bytes_to_string(b):
+        from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
+        byte_encoder = bytes_to_unicode()
+        return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
+
+    @staticmethod
+    def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: Optional[int] = None) -> list[bytes]:
+        parts = [bytes([b]) for b in token]
+        while True:
+            min_idx = None
+            min_rank = None
+            for i, pair in enumerate(zip(parts[:-1], parts[1:])):
+                rank = mergeable_ranks.get(pair[0] + pair[1])
+                if rank is not None and (min_rank is None or rank < min_rank):
+                    min_idx = i
+                    min_rank = rank
+            if min_rank is None or (max_rank is not None and min_rank >= max_rank):
+                break
+            assert min_idx is not None
+            parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
+        return parts
+
+    def set_vocab(self):
+        dir_model = self.dir_model
+        hparams = self.hparams
+        tokens: list[bytearray] = []
+        toktypes: list[int] = []
+
+        from transformers import AutoTokenizer  # type: ignore[attr-defined]
+        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
+        vocab_size = hparams["vocab_size"]
+        assert max(tokenizer.get_vocab().values()) < vocab_size
+
+        merges = []
+        vocab = {}
+        mergeable_ranks = tokenizer.mergeable_ranks
+        for token, rank in mergeable_ranks.items():
+            vocab[self.token_bytes_to_string(token)] = rank
+            if len(token) == 1:
+                continue
+            merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
+            assert len(merged) == 2
+            merges.append(' '.join(map(self.token_bytes_to_string, merged)))
+
+        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in vocab.items()}
+        added_vocab = tokenizer.special_tokens
+
+        for i in range(vocab_size):
+            if i not in reverse_vocab:
+                pad_token = f"[PAD{i}]".encode("utf-8")
+                tokens.append(bytearray(pad_token))
+                toktypes.append(gguf.TokenType.USER_DEFINED)
+            elif reverse_vocab[i] in added_vocab:
+                tokens.append(reverse_vocab[i])
+                toktypes.append(gguf.TokenType.CONTROL)
+            else:
+                tokens.append(reverse_vocab[i])
+                toktypes.append(gguf.TokenType.NORMAL)
+
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
+        special_vocab.merges = merges
+        special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"])
+        special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"])
+        special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"])
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        self.gguf_writer.add_name("Qwen")
+        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
+        self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
+        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
+        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
+
+    def write_tensors(self):
+        block_count = self.hparams["num_hidden_layers"]
+        model_kv = dict(self.get_tensors())
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        for name, data_torch in model_kv.items():
+            # we don't need these
+            if name.endswith(".rotary_emb.inv_freq"):
+                continue
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            self.gguf_writer.add_tensor(new_name, data)
+
 ###### CONVERSION LOGIC ######
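The part of this class worth calling out is QwenModel.bpe: Qwen ships a tiktoken-style tokenizer that only exposes mergeable_ranks (bytes -> rank), while GGUF's "gpt2" tokenizer model needs explicit merge rules, so set_vocab recovers each merge by re-splitting a token while allowing only merges of strictly lower rank; the two pieces left over are exactly the pair that formed it. Below is a minimal, self-contained sketch of that idea against a made-up rank table (the toy tokens and ranks are illustrative, not taken from the real Qwen vocabulary):

from typing import Optional

def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: Optional[int] = None) -> list[bytes]:
    # Same algorithm as QwenModel.bpe above: repeatedly merge the lowest-ranked
    # adjacent pair, but refuse any merge whose rank is >= max_rank.
    parts = [bytes([b]) for b in token]
    while True:
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
    return parts

# Toy rank table -- illustrative only, not the real Qwen vocabulary.
toy_ranks = {b"h": 0, b"e": 1, b"l": 2, b"o": 3, b"he": 4, b"ll": 5, b"hell": 6, b"hello": 7}

for token, rank in toy_ranks.items():
    if len(token) == 1:
        continue  # single bytes are never the result of a merge
    left, right = bpe(toy_ranks, token, max_rank=rank)  # exactly two parts, as asserted in set_vocab
    print(f"{token!r} = {left!r} + {right!r}")
# b'hello' (rank 7) splits into b'hell' + b'o', which set_vocab would record as the merge "hell o"

set_vocab then runs token_bytes_to_string over both halves and stores the joined result in the GGUF merges list, which is why the assert len(merged) == 2 holds for a well-formed BPE vocabulary.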
gguf-py/gguf/constants.py (+20)

@@ -92,6 +92,7 @@ class MODEL_ARCH(IntEnum):
    BERT = auto()
    BLOOM = auto()
    STABLELM = auto()
+   QWEN = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -132,6 +133,7 @@ class MODEL_TENSOR(IntEnum):
    MODEL_ARCH.BERT: "bert",
    MODEL_ARCH.BLOOM: "bloom",
    MODEL_ARCH.STABLELM: "stablelm",
+   MODEL_ARCH.QWEN: "qwen",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -317,6 +319,20 @@ class MODEL_TENSOR(IntEnum):
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
+   MODEL_ARCH.QWEN: [
+       MODEL_TENSOR.TOKEN_EMBD,
+       MODEL_TENSOR.OUTPUT_NORM,
+       MODEL_TENSOR.OUTPUT,
+       MODEL_TENSOR.ROPE_FREQS,
+       MODEL_TENSOR.ATTN_NORM,
+       MODEL_TENSOR.ATTN_QKV,
+       MODEL_TENSOR.ATTN_OUT,
+       MODEL_TENSOR.ATTN_ROT_EMBD,
+       MODEL_TENSOR.FFN_NORM,
+       MODEL_TENSOR.FFN_GATE,
+       MODEL_TENSOR.FFN_DOWN,
+       MODEL_TENSOR.FFN_UP,
+   ],
    MODEL_ARCH.GPT2: [
        # TODO
    ],
@@ -336,6 +352,10 @@ class MODEL_TENSOR(IntEnum):
    MODEL_ARCH.PERSIMMON: [
        MODEL_TENSOR.ROPE_FREQS,
    ],
+   MODEL_ARCH.QWEN: [
+       MODEL_TENSOR.ROPE_FREQS,
+       MODEL_TENSOR.ATTN_ROT_EMBD,
+   ],
 }
 
 #
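The constants change registers the architecture under the name "qwen" and declares which tensor kinds it uses, plus the rotary tensors that may legitimately be absent. A quick way to inspect this from a llama.cpp checkout is sketched below; it assumes the bundled gguf-py package is importable as gguf and that the two dicts extended above are exported as MODEL_ARCH_NAMES and MODEL_TENSORS (those names are not visible in this excerpt, so treat them as assumptions):

import gguf  # gguf-py from this checkout

# Architecture enum member and its registered name string.
arch = gguf.MODEL_ARCH.QWEN
print(arch, gguf.MODEL_ARCH_NAMES[arch])  # expected to print the "qwen" name added above

# Tensor kinds declared for qwen, resolved through TENSOR_NAMES to their GGUF name templates.
for tensor in gguf.MODEL_TENSORS[arch]:
    print(tensor.name, "->", gguf.TENSOR_NAMES[tensor])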
gguf-py/gguf/tensor_mapping.py (+10 -8)

@@ -10,7 +10,7 @@ class TensorNameMap:
        # Token embeddings
        MODEL_TENSOR.TOKEN_EMBD: (
            "gpt_neox.embed_in",            # gptneox
-           "transformer.wte",              # gpt2 gpt-j mpt refact
+           "transformer.wte",              # gpt2 gpt-j mpt refact qwen
            "transformer.word_embeddings",  # falcon
            "word_embeddings",              # bloom
            "model.embed_tokens",           # llama-hf
@@ -38,7 +38,7 @@
        # Output
        MODEL_TENSOR.OUTPUT: (
            "embed_out",                 # gptneox
-           "lm_head",                   # gpt2 mpt falcon llama-hf baichuan
+           "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen
            "output",                    # llama-pth bloom
            "word_embeddings_for_head",  # persimmon
        ),
@@ -51,7 +51,7 @@
            "norm",                                    # llama-pth
            "embeddings.LayerNorm",                    # bert
            "transformer.norm_f",                      # mpt
-           "ln_f",                                    # refact bloom
+           "ln_f",                                    # refact bloom qwen
            "language_model.encoder.final_layernorm",  # persimmon
        ),
 
@@ -65,7 +65,7 @@
        # Attention norm
        MODEL_TENSOR.ATTN_NORM: (
            "gpt_neox.layers.{bid}.input_layernorm",  # gptneox
-           "transformer.h.{bid}.ln_1",               # gpt2 gpt-j refact
+           "transformer.h.{bid}.ln_1",               # gpt2 gpt-j refact qwen
            "transformer.blocks.{bid}.norm_1",        # mpt
            "transformer.h.{bid}.input_layernorm",    # falcon7b
            "h.{bid}.input_layernorm",                # bloom
@@ -85,7 +85,7 @@
        # Attention query-key-value
        MODEL_TENSOR.ATTN_QKV: (
            "gpt_neox.layers.{bid}.attention.query_key_value",     # gptneox
-           "transformer.h.{bid}.attn.c_attn",                      # gpt2
+           "transformer.h.{bid}.attn.c_attn",                      # gpt2 qwen
            "transformer.blocks.{bid}.attn.Wqkv",                   # mpt
            "transformer.h.{bid}.self_attention.query_key_value",   # falcon
            "h.{bid}.self_attention.query_key_value",               # bloom
@@ -119,7 +119,7 @@
        # Attention output
        MODEL_TENSOR.ATTN_OUT: (
            "gpt_neox.layers.{bid}.attention.dense",     # gptneox
-           "transformer.h.{bid}.attn.c_proj",           # gpt2 refact
+           "transformer.h.{bid}.attn.c_proj",           # gpt2 refact qwen
            "transformer.blocks.{bid}.attn.out_proj",    # mpt
            "transformer.h.{bid}.self_attention.dense",  # falcon
            "h.{bid}.self_attention.dense",              # bloom
@@ -139,7 +139,7 @@
        # Feed-forward norm
        MODEL_TENSOR.FFN_NORM: (
            "gpt_neox.layers.{bid}.post_attention_layernorm",  # gptneox
-           "transformer.h.{bid}.ln_2",                         # gpt2 refact
+           "transformer.h.{bid}.ln_2",                         # gpt2 refact qwen
            "h.{bid}.post_attention_layernorm",                 # bloom
            "transformer.blocks.{bid}.norm_2",                  # mpt
            "model.layers.{bid}.post_attention_layernorm",      # llama-hf
@@ -161,18 +161,20 @@
            "encoder.layer.{bid}.intermediate.dense",                 # bert
            "transformer.h.{bid}.mlp.fc_in",                          # gpt-j
            "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h",  # persimmon
+           "transformer.h.{bid}.mlp.w1",                             # qwen
        ),
 
        # Feed-forward gate
        MODEL_TENSOR.FFN_GATE: (
            "model.layers.{bid}.mlp.gate_proj",  # llama-hf refact
            "layers.{bid}.feed_forward.w1",      # llama-pth
+           "transformer.h.{bid}.mlp.w2",        # qwen
        ),
 
        # Feed-forward down
        MODEL_TENSOR.FFN_DOWN: (
            "gpt_neox.layers.{bid}.mlp.dense_4h_to_h",  # gptneox
-           "transformer.h.{bid}.mlp.c_proj",           # gpt2 refact
+           "transformer.h.{bid}.mlp.c_proj",           # gpt2 refact qwen
            "transformer.blocks.{bid}.ffn.down_proj",   # mpt
            "transformer.h.{bid}.mlp.dense_4h_to_h",    # falcon
            "h.{bid}.mlp.dense_4h_to_h",                # bloom

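These tensor_mapping entries are what the converter's write_tensors() (shown above) consults through gguf.get_tensor_name_map. A small sketch of that lookup for the qwen architecture; run it from a llama.cpp checkout so the bundled gguf-py is importable, and note that the block count of 32 is just an example value:

import gguf

# HF-name -> GGUF-name map for qwen; the converter passes hparams["num_hidden_layers"]
# as the block count, 32 here is only an example.
tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.QWEN, 32)

for hf_name in (
    "transformer.wte.weight",               # token embeddings
    "transformer.h.0.attn.c_attn.weight",   # fused QKV projection
    "transformer.h.0.mlp.w1.weight",        # mapped to the FFN up projection
    "transformer.h.0.mlp.w2.weight",        # mapped to the FFN gate
    "lm_head.weight",                       # output head
):
    print(hf_name, "->", tensor_map.get_name(hf_name, try_suffixes=(".weight", ".bias")))

Names that fail to resolve come back as None, which is why write_tensors() prints "Can not map tensor" and exits when get_name() returns nothing.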