Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for BERT embedding models #5423

Merged
merged 21 commits into from
Feb 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
[flake8]
max-line-length = 125
ignore = W503
94 changes: 94 additions & 0 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ def from_model_architecture(model_architecture):
return InternLM2Model
if model_architecture == "MiniCPMForCausalLM":
return MiniCPMModel
if model_architecture == "BertModel":
return BertModel
return Model

def _is_model_safetensors(self) -> bool:
Expand Down Expand Up @@ -264,6 +266,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
return gguf.MODEL_ARCH.INTERNLM2
if arch == "MiniCPMForCausalLM":
return gguf.MODEL_ARCH.MINICPM
if arch == "BertModel":
return gguf.MODEL_ARCH.BERT

raise NotImplementedError(f'Architecture "{arch}" not supported!')

Expand Down Expand Up @@ -1629,6 +1633,96 @@ def write_tensors(self):
self.post_write_tensors(tensor_map, name, data_torch)


class BertModel(Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.block_count = self.hparams["num_hidden_layers"]

def set_gguf_parameters(self):
# TODO(cebtenzzre): merge with parent class
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: resolve this before merge

Copy link
Contributor

@iacore iacore Jun 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

have you... have you forgotten about this...

self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
self.gguf_writer.add_causal_attention(False)
self.gguf_writer.add_file_type(self.ftype)

def set_vocab(self):
path = self.dir_model
added_tokens_path = self.dir_model if self.dir_model.exists() else None

# use huggingface vocab to get all tokens
vocab = HfVocab(path, added_tokens_path)
tokens, scores, toktypes = zip(*vocab.all_tokens())
assert len(tokens) == vocab.vocab_size

# we need this to validate the size of the token_type embeddings
# though currently we are passing all zeros to the token_type embeddings
n_token_types = len(set(toktypes))
self.gguf_writer.add_token_type_count(n_token_types)

# convert to phantom space vocab
def phantom(tok, typ):
if tok.startswith(b"[") and tok.endswith(b"]"):
return tok
if tok.startswith(b"##"):
return tok[2:]
return b"\xe2\x96\x81" + tok
tokens = [phantom(t, y) for t, y in zip(tokens, toktypes)]

# set up bos and eos tokens (cls and sep)
self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id)

# add vocab to gguf
self.gguf_writer.add_tokenizer_model("bert")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)

# handle special tokens
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)

def write_tensors(self):
tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
tensors = dict(self.get_tensors())
for name, data_torch in tensors.items():
# we are only using BERT for embeddings so we don't need the pooling layer
if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
continue # we don't need these

# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()

data = data_torch.squeeze().numpy()
n_dims = len(data.shape)
new_dtype: type[np.floating[Any]]

if (
self.ftype == 1 and name.endswith(".weight") and n_dims == 2
and name != "embeddings.token_type_embeddings.weight" # not used with get_rows, must be F32
):
# if f16 desired, convert any float32 2-dim weight tensors to float16
new_dtype = np.float16
else:
# if f32 desired, convert any float16 to float32
new_dtype = np.float32

print(f"{new_name}, n_dims = {n_dims}, {data_torch.dtype} --> {new_dtype}")

if data.dtype != new_dtype:
data = data.astype(new_dtype)

self.gguf_writer.add_tensor(new_name, data)


###### CONVERSION LOGIC ######


Expand Down
12 changes: 11 additions & 1 deletion examples/embedding/embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,17 @@ int main(int argc, char ** argv) {
}

const int n_embd = llama_n_embd(model);
const auto * embeddings = llama_get_embeddings(ctx);
auto * embeddings = llama_get_embeddings(ctx);

// l2-normalize embeddings
float norm = 0;
for (int i = 0; i < n_embd; i++) {
norm += embeddings[i] * embeddings[i];
}
norm = sqrt(norm);
for (int i = 0; i < n_embd; i++) {
embeddings[i] /= norm;
}

for (int i = 0; i < n_embd; i++) {
printf("%f ", embeddings[i]);
Expand Down
43 changes: 25 additions & 18 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class Attention:
VALUE_LENGTH = "{arch}.attention.value_length"
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
CAUSAL = "{arch}.attention.causal"

class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
Expand All @@ -60,22 +61,23 @@ class Rope:
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"

class Tokenizer:
MODEL = "tokenizer.ggml.model"
LIST = "tokenizer.ggml.tokens"
TOKEN_TYPE = "tokenizer.ggml.token_type"
SCORES = "tokenizer.ggml.scores"
MERGES = "tokenizer.ggml.merges"
BOS_ID = "tokenizer.ggml.bos_token_id"
EOS_ID = "tokenizer.ggml.eos_token_id"
UNK_ID = "tokenizer.ggml.unknown_token_id"
SEP_ID = "tokenizer.ggml.seperator_token_id"
PAD_ID = "tokenizer.ggml.padding_token_id"
ADD_BOS = "tokenizer.ggml.add_bos_token"
ADD_EOS = "tokenizer.ggml.add_eos_token"
ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
HF_JSON = "tokenizer.huggingface.json"
RWKV = "tokenizer.rwkv.world"
CHAT_TEMPLATE = "tokenizer.chat_template"
MODEL = "tokenizer.ggml.model"
LIST = "tokenizer.ggml.tokens"
TOKEN_TYPE = "tokenizer.ggml.token_type"
TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types
SCORES = "tokenizer.ggml.scores"
MERGES = "tokenizer.ggml.merges"
BOS_ID = "tokenizer.ggml.bos_token_id"
EOS_ID = "tokenizer.ggml.eos_token_id"
UNK_ID = "tokenizer.ggml.unknown_token_id"
SEP_ID = "tokenizer.ggml.seperator_token_id"
PAD_ID = "tokenizer.ggml.padding_token_id"
ADD_BOS = "tokenizer.ggml.add_bos_token"
ADD_EOS = "tokenizer.ggml.add_eos_token"
ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
HF_JSON = "tokenizer.huggingface.json"
RWKV = "tokenizer.rwkv.world"
CHAT_TEMPLATE = "tokenizer.chat_template"


#
Expand Down Expand Up @@ -122,6 +124,7 @@ class MODEL_TENSOR(IntEnum):
ATTN_OUT = auto()
ATTN_NORM = auto()
ATTN_NORM_2 = auto()
ATTN_OUT_NORM = auto()
ATTN_ROT_EMBD = auto()
FFN_GATE_INP = auto()
FFN_NORM = auto()
Expand All @@ -134,6 +137,7 @@ class MODEL_TENSOR(IntEnum):
FFN_UP_EXP = auto()
ATTN_Q_NORM = auto()
ATTN_K_NORM = auto()
LAYER_OUT_NORM = auto()


MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
Expand Down Expand Up @@ -178,6 +182,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
Expand All @@ -187,6 +192,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate.{xid}",
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}",
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}",
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
}

MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
Expand Down Expand Up @@ -262,17 +268,18 @@ class MODEL_TENSOR(IntEnum):
],
MODEL_ARCH.BERT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.TOKEN_TYPES,
MODEL_TENSOR.POS_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_OUT_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.LAYER_OUT_NORM,
],
MODEL_ARCH.MPT: [
MODEL_TENSOR.TOKEN_EMBD,
Expand Down
6 changes: 6 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,9 @@ def add_layer_norm_eps(self, value: float) -> None:
def add_layer_norm_rms_eps(self, value: float) -> None:
self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)

def add_causal_attention(self, value: bool) -> None:
self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)

def add_rope_dimension_count(self, count: int) -> None:
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)

Expand Down Expand Up @@ -387,6 +390,9 @@ def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[by
def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None:
self.add_array(Keys.Tokenizer.TOKEN_TYPE, types)

def add_token_type_count(self, value: int) -> None:
self.add_uint32(Keys.Tokenizer.TOKEN_TYPE_COUNT, value)

def add_token_scores(self, scores: Sequence[float]) -> None:
self.add_array(Keys.Tokenizer.SCORES, scores)

Expand Down
13 changes: 10 additions & 3 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class TensorNameMap:
# Normalization of token embeddings
MODEL_TENSOR.TOKEN_EMBD_NORM: (
"word_embeddings_layernorm", # bloom
"embeddings.LayerNorm", # bert
),

# Position embeddings
Expand All @@ -54,7 +55,6 @@ class TensorNameMap:
"transformer.ln_f", # gpt2 gpt-j falcon
"model.norm", # llama-hf baichuan internlm2
"norm", # llama-pth
"embeddings.LayerNorm", # bert
"transformer.norm_f", # mpt
"ln_f", # refact bloom qwen gpt2
"language_model.encoder.final_layernorm", # persimmon
Expand All @@ -79,7 +79,6 @@ class TensorNameMap:
"transformer.h.{bid}.ln_mlp", # falcon40b
"model.layers.{bid}.input_layernorm", # llama-hf
"layers.{bid}.attention_norm", # llama-pth
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
"model.layers.{bid}.ln1", # yi
"h.{bid}.ln_1", # gpt2
Expand Down Expand Up @@ -155,6 +154,11 @@ class TensorNameMap:
"model.layers.{bid}.attention.wo", # internlm2
),

# Attention output norm
MODEL_TENSOR.ATTN_OUT_NORM: (
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
),

# Rotary embeddings
MODEL_TENSOR.ATTN_ROT_EMBD: (
"model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
Expand All @@ -171,7 +175,6 @@ class TensorNameMap:
"transformer.blocks.{bid}.norm_2", # mpt
"model.layers.{bid}.post_attention_layernorm", # llama-hf
"layers.{bid}.ffn_norm", # llama-pth
"encoder.layer.{bid}.output.LayerNorm", # bert
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
"model.layers.{bid}.ln2", # yi
"h.{bid}.ln_2", # gpt2
Expand Down Expand Up @@ -266,6 +269,10 @@ class TensorNameMap:
MODEL_TENSOR.ROPE_FREQS: (
"language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon
),

MODEL_TENSOR.LAYER_OUT_NORM: (
"encoder.layer.{bid}.output.LayerNorm", # bert
)
}

mapping: dict[str, tuple[MODEL_TENSOR, str]]
Expand Down
Loading
Loading