convert-new.py: 125 changes (34 additions, 91 deletions)
@@ -45,14 +45,6 @@ class UnquantizedDataType:
 
 DataType = Union[UnquantizedDataType]
 
-DATA_TYPE_TO_FTYPE: Dict[DataType, int] = {
-    DT_F32: 0,
-    DT_F16: 1,
-}
-
-FTYPE_TO_DATA_TYPE: Dict[int, DataType] = \
-    {ftype: dtype for (dtype, ftype) in DATA_TYPE_TO_FTYPE.items()}
-
 DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
     DT_BF16: np.dtype(np.uint16),
     DT_F16: np.dtype(np.float16),
@@ -78,31 +70,6 @@ def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
         else:
             raise ValueError(self)
 
-# TODO: this is LLaMA specific
-def make_tensors_list() -> List[str]:
-    ret = [
-        'tok_embeddings.weight',
-        'norm.weight',
-        'output.weight',
-    ]
-    for i in range(80): # maximum number of layer
-        ret += [
-            f'layers.{i}.attention.wq.weight',
-            f'layers.{i}.attention.wk.weight',
-            f'layers.{i}.attention.wv.weight',
-            f'layers.{i}.attention.wo.weight',
-            f'layers.{i}.attention_norm.weight',
-            f'layers.{i}.feed_forward.w1.weight',
-            f'layers.{i}.feed_forward.w2.weight',
-            f'layers.{i}.feed_forward.w3.weight',
-            f'layers.{i}.ffn_norm.weight',
-        ]
-    return ret
-
-# TODO: this should be generalized for non-LLaMA models
-TENSORS_LIST = make_tensors_list()
-TENSORS_SET = set(TENSORS_LIST)
-
 def find_n_mult(n_ff: int, n_embd: int) -> int:
     # hardcoded magic range
     for n_mult in range(8192, 1, -1):
@@ -533,34 +500,6 @@ def load() -> Tensor:
     s[0] = s[0] // 3
     return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description)
 
-def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
-    out: LazyModel = {}
-    out["tok_embeddings.weight"] = model["model.embed_tokens.weight"]
-    out["norm.weight"] = model["model.norm.weight"]
-    out["output.weight"] = model["lm_head.weight"]
-
-    for i in itertools.count():
-        if f"model.layers.{i}.self_attn.q_proj.weight" in model:
-            out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head, params.n_head_kv)
-            out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_head_kv)
-            out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
-        elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
-            out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
-            out[f"layers.{i}.attention.wk.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head)
-            out[f"layers.{i}.attention.wv.weight"] = part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
-        else:
-            break
-
-        out[f"layers.{i}.attention.wo.weight"] = model[f"model.layers.{i}.self_attn.o_proj.weight"]
-
-        out[f"layers.{i}.feed_forward.w1.weight"] = model[f"model.layers.{i}.mlp.gate_proj.weight"]
-        out[f"layers.{i}.feed_forward.w2.weight"] = model[f"model.layers.{i}.mlp.down_proj.weight"]
-        out[f"layers.{i}.feed_forward.w3.weight"] = model[f"model.layers.{i}.mlp.up_proj.weight"]
-
-        out[f"layers.{i}.attention_norm.weight"] = model[f"model.layers.{i}.input_layernorm.weight"]
-        out[f"layers.{i}.ffn_norm.weight"] = model[f"model.layers.{i}.post_attention_layernorm.weight"]
-    return out
-
 
 # Functionality that simulates `torch.load` but where individual tensors are
 # only loaded into memory on demand, not all at once.
@@ -750,8 +689,8 @@ class OutputFile:
     def __init__(self, fname_out: Path) -> None:
         self.gguf = gguf.GGUFWriter.open(fname_out)
 
-    def add_meta_arch(self, params: Params, file_type: GGMLFileType) -> None:
-        llm_arch = "llama"
+    def add_meta_arch(self, params: Params) -> None:
+        llm_arch = gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.LLAMA]
 
         self.gguf.add_architecture (llm_arch)
         self.gguf.add_context_length (llm_arch, params.n_ctx)
@@ -763,13 +702,6 @@ def add_meta_arch(self, params: Params, file_type: GGMLFileType) -> None:
         self.gguf.add_head_count_kv (llm_arch, params.n_head_kv)
         self.gguf.add_layer_norm_rms_eps (llm_arch, params.f_norm_eps)
 
-    #def write_tensor_header(self, name: str, shape: Sequence[int], data_type: DataType) -> None:
-    #    sname = name.encode('utf-8')
-    #    self.fout.write(struct.pack("iii", len(shape), len(sname), DATA_TYPE_TO_FTYPE[data_type]))
-    #    self.fout.write(struct.pack("i" * len(shape), *shape[::-1]))
-    #    self.fout.write(sname)
-    #    self.fout.seek((self.fout.tell() + 31) & -32)
-
     def add_meta_vocab(self, vocab: Vocab) -> None:
         tokens = []
         scores = []
@@ -794,17 +726,17 @@ def close(self) -> None:
     @staticmethod
     def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab) -> None:
         of = OutputFile(fname_out)
-        of.add_meta_arch(params, file_type=GGMLFileType.AllF32)
+        of.add_meta_arch(params)
         of.add_meta_vocab(vocab)
         of.write_meta()
         of.close()
 
     @staticmethod
-    def write_all(fname_out: Path, params: Params, file_type: GGMLFileType, model: LazyModel, vocab: Vocab) -> None:
+    def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
         check_vocab_size(params, vocab)
 
         of = OutputFile(fname_out)
-        of.add_meta_arch(params, file_type)
+        of.add_meta_arch(params)
         of.add_meta_vocab(vocab)
 
         def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
@@ -822,21 +754,39 @@ def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
 
 
 def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
-    wq_type = model["layers.0.attention.wq.weight"].data_type
+    wq_type = model[gguf.MODEL_TENSOR_NAMES[gguf.MODEL_ARCH.LLAMA][gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type
 
     if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)):
         return GGMLFileType.AllF32
     if output_type_str == "f16" or (output_type_str is None and wq_type == DT_F16):
         return GGMLFileType.MostlyF16
 
     name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
 
     raise Exception(f"Unexpected combination of types: {name_to_type}")
 
 
-def do_necessary_conversions(model: LazyModel, params: Params) -> LazyModel:
-    if "lm_head.weight" in model:
-        model = convert_transformers_to_orig(model, params)
-    model = filter_and_sort_tensors(model)
+def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
+    tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAMA, params.n_layer)
 
-    return model
+    out: LazyModel = {}
+    for name, lazy_tensor in model.items():
+        name_new = name
+
+        if name in tmap:
+            name_new = tmap[name]
+        elif name.endswith(".weight") and name[:-7] in tmap:
+            name_new = tmap[name[:-7]] + ".weight"
+        elif name.endswith(".bias") and name[:-5] in tmap:
+            name_new = tmap[name[:-5]] + ".bias"
+        else:
+            raise Exception(f"Unexpected tensor name: {name}")
+
+        out[name_new] = lazy_tensor
+
+        print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}")
+
+    return out
 
 
 def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
@@ -893,11 +843,6 @@ def load_some_model(path: Path) -> ModelPlus:
             # Try the PyTorch patterns too, with lower priority
             globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"]
             files = [file for glob in globs for file in path.glob(glob)]
-        if not files:
-            # Try GGML too, but with lower priority, since if both a non-GGML
-            # model and a GGML model exist in the same directory, we assume the
-            # latter was converted from the former.
-            files = list(path.glob("ggml-model*.bin*"))
         if not files:
             raise Exception(f"Can't find model in directory {path}")
         if len(files) > 1:
@@ -914,10 +859,6 @@ def load_some_model(path: Path) -> ModelPlus:
     return model_plus
 
 
-def filter_and_sort_tensors(model: LazyModel) -> LazyModel:
-    return {name: model[name] for name in TENSORS_LIST if name in model}
-
-
 def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, SentencePieceVocab]:
     # Be extra-friendly and accept either a file or a directory. Also, if it's
     # a directory, it might be the model directory, and tokenizer.model might
@@ -937,8 +878,10 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, SentencePieceVocab]:
             raise FileNotFoundError(
                 f"Could not find tokenizer.model in {path} or its parent; "
                 "if it's in another directory, pass the directory as --vocab-dir")
-    added_tokens_path = path.parent / "added_tokens.json"
 
+    print(f"Loading vocab file '{path}', type '{vocabtype}'")
+
+    added_tokens_path = path.parent / "added_tokens.json"
     if vocabtype == "bpe":
         return BpeVocab(path, added_tokens_path if added_tokens_path.exists() else None)
     elif vocabtype == "spm":
@@ -1018,12 +961,12 @@ def main(args_in: Optional[List[str]] = None) -> None:
         vocab = load_vocab(vocab_dir, args.vocabtype)
 
     model = model_plus.model
-    model = do_necessary_conversions(model, params) # TODO: utilize gguf.get_tensor_name_map
+    model = convert_model_names(model, params) # TODO: utilize gguf.get_tensor_name_map
    output_type = pick_output_type(model, args.outtype)
     model = convert_to_output_type(model, output_type)
     outfile = args.outfile or default_outfile(model_plus.paths, output_type)
 
-    OutputFile.write_all(outfile, params, output_type, model, vocab)
+    OutputFile.write_all(outfile, params, model, vocab)
     print(f"Wrote {outfile}")
 
 