diff --git a/.gitignore b/.gitignore
index e1b2bc92f..f4a00b121 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,7 @@ cache_dir
 example-models
 *.onnx
 *.onnx.data
+__pycache__
 benchmark/python/*.csv

 !test/test_models/hf-internal-testing/
diff --git a/src/python/py/models/README.md b/src/python/py/models/README.md
index 4dc30e7e2..f39c1d941 100644
--- a/src/python/py/models/README.md
+++ b/src/python/py/models/README.md
@@ -2,6 +2,19 @@

 This folder contains the model builder tool, which greatly accelerates creating optimized and quantized ONNX models that run with ONNX Runtime GenAI.

+## Contents
+ - [Current Support](#current-support)
+ - [Usage](#usage)
+ - [Full Usage](#full-usage)
+ - [Original PyTorch Model from Hugging Face](#original-pytorch-model-from-hugging-face)
+ - [Original PyTorch Model from Disk](#original-pytorch-model-from-disk)
+ - [Customized or Finetuned PyTorch Model](#customized-or-finetuned-pytorch-model)
+ - [GGUF Model](#gguf-model)
+ - [Extra Options](#extra-options)
+ - [Unit Testing Models](#unit-testing-models)
+ - [Option 1: Use the model builder tool directly](#option-1-use-the-model-builder-tool-directly)
+ - [Option 2: Edit the config.json file](#option-2-edit-the-configjson-file-on-disk-and-then-run-the-model-builder-tool)
+
 ## Current Support
 The tool currently supports the following model architectures.

@@ -22,44 +35,54 @@ python3 -m onnxruntime_genai.models.builder --help
 python3 builder.py --help
 ```

-### Original Model From Hugging Face
+### Original PyTorch Model from Hugging Face
 This scenario is where your PyTorch model is not downloaded locally (either in the default Hugging Face cache directory or in a local folder on disk).
 ```
 # From wheel:
-python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files
+python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_save_hf_files

 # From source:
-python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files
+python3 builder.py -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_save_hf_files
 ```

-### Original Model From Disk
+### Original PyTorch Model from Disk
 This scenario is where your PyTorch model is already downloaded locally (either in the default Hugging Face cache directory or in a local folder on disk).
 ```
 # From wheel:
-python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
+python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved

 # From source:
-python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
+python3 builder.py -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
 ```

-### Customized or Finetuned Model
+### Customized or Finetuned PyTorch Model
 This scenario is where your PyTorch model has been customized or finetuned for one of the currently supported model architectures and your model can be loaded by Hugging Face.
 ```
 # From wheel:
-python3 -m onnxruntime_genai.models.builder -m path_to_local_folder_on_disk -o /path/to/output/folder -p precision -e execution_provider
+python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider
+
+# From source:
+python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider
+```
+
+### GGUF Model
+This scenario is where your float16/float32 GGUF model is already on disk.
+```
+# From wheel:
+python3 -m onnxruntime_genai.models.builder -m model_name -i path_to_gguf_file -o path_to_output_folder -p precision -e execution_provider -c cache_dir_for_hf_files

 # From source:
-python3 builder.py -m path_to_local_folder_on_disk -o /path/to/output/folder -p precision -e execution_provider
+python3 builder.py -m model_name -i path_to_gguf_file -o path_to_output_folder -p precision -e execution_provider -c cache_dir_for_hf_files
 ```

 ### Extra Options
 This scenario is for when you want to have control over some specific settings. The example below shows how you can pass key-value arguments to `--extra_options`.
 ```
 # From wheel:
-python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files --extra_options filename=decoder.onnx
+python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_for_hf_files --extra_options filename=decoder.onnx

 # From source:
-python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files --extra_options filename=decoder.onnx
+python3 builder.py -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_for_hf_files --extra_options filename=decoder.onnx
 ```

 To see all available options through `--extra_options`, please use the `help` commands in the `Full Usage` section above.
@@ -83,10 +106,10 @@ tokenizer.save_pretrained(cache_dir)
 This option is the simplest, but it will download another copy of the PyTorch model onto disk to accommodate the change in the number of hidden layers.
 ```
 # From wheel:
-python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider --extra_options num_hidden_layers=4
+python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p precision -e execution_provider --extra_options num_hidden_layers=4

 # From source:
-python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider --extra_options num_hidden_layers=4
+python3 builder.py -m model_name -o path_to_output_folder -p precision -e execution_provider --extra_options num_hidden_layers=4
 ```

 #### Option 2: Edit the config.json file on disk and then run the model builder tool
@@ -97,8 +120,8 @@ python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execu

 ```
 # From wheel:
-python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
+python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved

 # From source:
-python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
+python3 builder.py -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
 ```
\ No newline at end of file
diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py
index 1f063d970..c8ccd82d5 100644
--- a/src/python/py/models/builder.py
+++ b/src/python/py/models/builder.py
@@ -100,11 +100,13 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         }

         # RotaryEmbedding-specific variables
+        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
         self.rotemb_attrs = {
-            "create_rotary_embedding_caches": True,  # Create cos/sin caches for rotary embeddings
-            "num_heads": 0,                          # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
-            "rotary_embedding_dim": 0,               # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
-            "theta": config.rope_theta,              # Base value if calculating cos/sin caches from scratch
+            "create_rotary_embedding_caches": True,          # Create cos/sin caches for rotary embeddings
+            "partial_rotary_factor": partial_rotary_factor,  # Factor for partial rotary embeddings
+            "num_heads": 0,                                  # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
+            "rotary_embedding_dim": 0,                       # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
+            "theta": config.rope_theta,                      # Base value if calculating cos/sin caches from scratch
         }

         # Attention-specific variables (MHA, GQA, GQA + Rot.Emb., etc.)
@@ -545,7 +547,8 @@ def make_rotary_embedding(self, rotemb, name, root_input, **kwargs):
         if self.rotemb_attrs["create_rotary_embedding_caches"]:
             if not hasattr(rotemb, "cos_cached"):
                 # Create cos/sin caches if not already created
-                inv_freq = 1.0 / (self.rotemb_attrs["theta"] ** (torch.arange(0, self.head_size, 2, dtype=torch.int64).float() / self.head_size))
+                dim = int(self.rotemb_attrs["partial_rotary_factor"] * self.head_size)
+                inv_freq = 1.0 / (self.rotemb_attrs["theta"] ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
                 t = torch.arange(self.context_length, dtype=torch.int64).type_as(inv_freq)
                 freqs = torch.outer(t, inv_freq)
                 emb = torch.cat((freqs, freqs), dim=-1)
@@ -789,7 +792,7 @@ def make_layer(self, layer_id, layer):
         # Norm after last decoder layer of model (last layer --> norm)
         self.layernorm_attrs["last_layernorm"] = True

-    def make_model(self):
+    def make_model(self, input_path):
         # Make inputs and outputs to ONNX model
         self.make_inputs_and_outputs()

@@ -802,14 +805,22 @@ def make_model(self):
         # 4D causal attention mask
         self.make_attention_mask_reformatting()

-        # Load PyTorch model
-        extra_kwargs = {} if os.path.exists(self.model_name_or_path) else {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {"cache_dir": self.cache_dir, "use_auth_token": True}
-        model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, **extra_kwargs)
+        # Load weights of original model
+        if input_path.endswith(".gguf"):
+            # Load GGUF model
+            from gguf_model import GGUFModel
+            model = GGUFModel.from_pretrained(self.model_type, input_path, self.head_size, self.hidden_size, self.intermediate_size, self.num_attn_heads, self.num_kv_heads, self.vocab_size)
+            self.layernorm_attrs["add_offset"] = 0  # add offset already done for GGUF models
+        else:
+            # Load PyTorch model
+            extra_kwargs = {} if os.path.exists(self.model_name_or_path) else {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {"cache_dir": self.cache_dir, "use_auth_token": True}
+            model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, **extra_kwargs)

-        # Loop through PyTorch model and map each nn.Module to ONNX/ORT ops
+        # Loop through model and map each module to ONNX/ORT ops
         self.layer_id = 0
         for module in model.modules():
-            if isinstance(module, torch.nn.Embedding):
+            if isinstance(module, torch.nn.Embedding) or (hasattr(model, "embedding") and module == model.embedding):
+                # Checks (Hugging Face logic) or (GGUF logic)
                 # Embedding layer
                 print("Reading embedding layer")
                 self.make_embedding(module.weight.detach().numpy())
@@ -822,7 +833,8 @@ def make_model(self):
                 # SkipLayerNorm after last decoder layer (MatMul --> SkipLayerNorm)
                 print("Reading final norm")
                 self.make_layernorm(self.layer_id, module, skip=True, simple=self.layernorm_attrs["simple"], location="final_norm")
-            elif isinstance(module, torch.nn.Linear) and module.out_features == self.vocab_size:
+            elif (isinstance(module, torch.nn.Linear) and module.out_features == self.vocab_size) or (hasattr(model, "lm_head") and module == model.lm_head):
+                # Checks (Hugging Face logic) or (GGUF logic)
                 # Language modeling head (SkipLayerNorm --> logits)
                 print("Reading LM head")
                 self.make_lm_head(module)
@@ -830,9 +842,12 @@ def make_model(self):
         del model

     def has_final_norm(self, module, model):
-        norm = hasattr(model.model, "norm") and module == model.model.norm
-        final_layernorm = hasattr(model.model, "final_layernorm") and module == model.model.final_layernorm
-        return norm or final_layernorm
+        # Hugging Face names
+        hf_norm = hasattr(model, "model") and hasattr(model.model, "norm") and module == model.model.norm
+        hf_final_layernorm = hasattr(model, "model") and hasattr(model.model, "final_layernorm") and module == model.model.final_layernorm
+        # GGUF names
+        gguf_final_norm = hasattr(model, "final_norm") and module == model.final_norm
+        return hf_norm or hf_final_layernorm or gguf_final_norm

     def make_attention_mask_reformatting(self):
         # Make nodes for the attention mask subgraphs that reformat the
@@ -1362,11 +1377,15 @@ def parse_extra_options(kv_items):
     return kv_pairs


-def create_model(model_name_or_path, output_dir, precision, execution_provider, cache_dir, **extra_options):
+def create_model(model_name, input_path, output_dir, precision, execution_provider, cache_dir, **extra_options):
+    # Create cache and output directories
     os.makedirs(output_dir, exist_ok=True)
     os.makedirs(cache_dir, exist_ok=True)
-    extra_kwargs = {} if os.path.exists(model_name_or_path) else {"cache_dir": cache_dir, "use_auth_token": True}
-    config = AutoConfig.from_pretrained(model_name_or_path, **extra_kwargs)
+
+    # Load model config
+    extra_kwargs = {} if os.path.isdir(input_path) else {"cache_dir": cache_dir, "use_auth_token": True}
+    hf_name = input_path if os.path.isdir(input_path) else model_name
+    config = AutoConfig.from_pretrained(hf_name, **extra_kwargs)

     # Set input/output precision of ONNX model
     io_dtype = TensorProto.FLOAT if precision in {"int8", "fp32"} or (precision == "int4" and execution_provider == "cpu") else TensorProto.FLOAT16
@@ -1381,19 +1400,19 @@ def create_model(model_name_or_path, output_dir, precision, execution_provider,
     elif config.architectures[0] == "PhiForCausalLM":
         onnx_model = PhiModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
     else:
-        raise NotImplementedError(f"The {model_name_or_path} model is not currently supported.")
+        raise NotImplementedError(f"The {hf_name} model is not currently supported.")

     # Make ONNX model
-    onnx_model.make_model()
+    onnx_model.make_model(input_path)

     # Save ONNX model
     onnx_model.save_model(output_dir)

     # Make GenAI config
-    onnx_model.make_genai_config(model_name_or_path, extra_kwargs, output_dir)
+    onnx_model.make_genai_config(hf_name, extra_kwargs, output_dir)

     # Copy Hugging Face processing files to output folder
-    onnx_model.save_processing(model_name_or_path, extra_kwargs, output_dir)
+    onnx_model.save_processing(hf_name, extra_kwargs, output_dir)


 def get_args():
@@ -1401,9 +1420,22 @@

     parser.add_argument(
         "-m",
-        "--model_name_or_path",
-        required=True,
-        help="Model name in Hugging Face or path to folder on disk containing the Hugging Face config, model, tokenizer, etc.",
+        "--model_name",
+        required=False,
+        default=None,
+        help="Model name in Hugging Face. Do not use if providing an input path to a Hugging Face directory in -i/--input.",
+    )
+
+    parser.add_argument(
+        "-i",
+        "--input",
+        required=False,
+        default="",
+        help=textwrap.dedent("""\
+            Input model source. Currently supported options are:
+            hf_path: Path to folder on disk containing the Hugging Face config, model, tokenizer, etc.
+            gguf_path: Path to float16/float32 GGUF file on disk containing the GGUF model
+            """),
     )

     parser.add_argument(
@@ -1435,7 +1467,7 @@
         required=False,
         type=str,
         default=os.path.join('.', 'cache_dir'),
-        help="Model cache directory (if providing model name and not folder path)",
+        help="Cache directory for Hugging Face files and temporary ONNX external data files",
     )

     parser.add_argument(
@@ -1465,4 +1497,4 @@
 if __name__ == '__main__':
     args = get_args()
     extra_options = parse_extra_options(args.extra_options)
-    create_model(args.model_name_or_path, args.output, args.precision, args.execution_provider, args.cache_dir, **extra_options)
+    create_model(args.model_name, args.input, args.output, args.precision, args.execution_provider, args.cache_dir, **extra_options)
diff --git a/src/python/py/models/gguf_model.py b/src/python/py/models/gguf_model.py
new file mode 100644
index 000000000..96573b04a
--- /dev/null
+++ b/src/python/py/models/gguf_model.py
@@ -0,0 +1,253 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+"""
+A set of Python classes to mimic Hugging Face's PyTorch models with GGUF weights that are already in NumPy.
+
+The goal is for `GGUFModel` and a Hugging Face model produced by `AutoModel.from_pretrained(...)`
+to share the same attributes so that the original Hugging Face --> ONNX code remains the same
+no matter where the weights actually come from.
+"""
+
+from functools import reduce
+from gguf.gguf_reader import GGUFReader
+
+import re
+
+
+class GGUFTensor:
+    def __init__(self):
+        self.tensor = None
+
+    def detach(self):
+        """
+        No-op operation since NumPy tensors are on CPU
+        """
+        return self
+
+    def numpy(self):
+        """
+        Return tensor, which is already stored as a NumPy tensor
+        """
+        return self.tensor
+
+
+class GGUFTensorModule:
+    def __init__(self):
+        self.weight = GGUFTensor()
+        self.bias = None
+
+    def add_bias(self):
+        self.bias = GGUFTensor()
+
+
+class GGUFAttention:
+    def __init__(self):
+        self.q_proj = GGUFTensorModule()
+        self.k_proj = GGUFTensorModule()
+        self.v_proj = GGUFTensorModule()
+        self.o_proj = GGUFTensorModule()
+        self.rotary_emb = GGUFTensorModule()
+
+
+class GGUFMLP:
+    def __init__(self):
+        self.gate_proj = GGUFTensorModule()
+        self.up_proj = GGUFTensorModule()
+        self.down_proj = GGUFTensorModule()
+        self.fc1 = GGUFTensorModule()
+        self.fc2 = GGUFTensorModule()
+
+
+class GGUFDecoderLayer:
+    def __init__(self, layer_id):
+        self.layer_id = layer_id
+        self.input_layernorm = GGUFTensorModule()
+        self.self_attn = GGUFAttention()
+        self.post_attention_layernorm = GGUFTensorModule()
+        self.mlp = GGUFMLP()
+
+
+class GGUFModel:
+    def __init__(self, input_path, head_size, hidden_size, intermediate_size, num_attn_heads, num_kv_heads, vocab_size):
+        # Load GGUF model and read its info
+        reader = GGUFReader(input_path)
+
+        self.embedding = GGUFTensorModule()
+        self.final_norm = GGUFTensorModule()
+        self.lm_head = GGUFTensorModule()
+        self.layers = []
+
+        layer_id = 0
+        module = GGUFDecoderLayer(layer_id)
+        for tensor in sorted(reader.tensors, key=lambda t: t.name):
+            name = tensor.name
+
+            if name == "token_embd.weight":
+                # Remove tensor data's padding via `reduce` when GGUF model's vocab size is larger than the config's vocab size
+                embedding_shape = [vocab_size, hidden_size]
+                self.embedding.weight.tensor = tensor.data[ : reduce(lambda x, y: x*y, embedding_shape)].reshape(embedding_shape)
+            elif name == "output_norm.weight":
+                self.final_norm.weight.tensor = tensor.data
+            elif name == "output_norm.bias":
+                self.final_norm.add_bias()
+                self.final_norm.bias.tensor = tensor.data
+            elif name == "output.weight":
+                lm_head_shape = [vocab_size, hidden_size]
+                self.lm_head.weight.tensor = tensor.data.reshape(lm_head_shape)
+            elif name == "output.bias":
+                self.lm_head.add_bias()
+                self.lm_head.bias.tensor = tensor.data
+            else:
+                curr_layer_id = int(name.split(".")[1])
+                if curr_layer_id != layer_id:
+                    # Add layer to list of modules
+                    self.layers.append(module)
+                    layer_id = curr_layer_id
+                    module = GGUFDecoderLayer(layer_id)
+
+                # Map weights and biases of norm, attention, and feed-forward network
+                # Graph order is attn_norm --> attn_q/k/v --> attn_output --> ffn_norm --> ffn_gate/up --> ffn_down
+                if bool(re.match(r"^blk\.\d+\.attn_norm\.weight$", name)):
+                    # blk.layer_id.attn_norm.weight
+                    module.input_layernorm.weight.tensor = tensor.data
+                elif bool(re.match(r"^blk\.\d+\.attn_norm\.bias$", name)):
+                    # blk.layer_id.attn_norm.bias
+                    module.input_layernorm.add_bias()
+                    module.input_layernorm.bias.tensor = tensor.data
+                elif bool(re.match(r"^blk\.\d+\.attn_q\.weight$", name)):
+                    # blk.layer_id.attn_q.weight
+                    q_shape = [head_size * num_attn_heads, hidden_size]
+                    module.self_attn.q_proj.weight.tensor = tensor.data.reshape(q_shape)
+                elif bool(re.match(r"^blk\.\d+\.attn_q\.bias$", name)):
+                    # blk.layer_id.attn_q.bias
+                    module.self_attn.q_proj.add_bias()
+                    module.self_attn.q_proj.bias.tensor = tensor.data
+                elif bool(re.match(r"^blk\.\d+\.attn_k\.weight$", name)):
+                    # blk.layer_id.attn_k.weight
+                    k_shape = [head_size * num_kv_heads, hidden_size]
+                    module.self_attn.k_proj.weight.tensor = tensor.data.reshape(k_shape)
+                elif bool(re.match(r"^blk\.\d+\.attn_k\.bias$", name)):
+                    # blk.layer_id.attn_k.bias
+                    module.self_attn.k_proj.add_bias()
+                    module.self_attn.k_proj.bias.tensor = tensor.data
+                elif bool(re.match(r"^blk\.\d+\.attn_v\.weight$", name)):
+                    # blk.layer_id.attn_v.weight
+                    v_shape = [head_size * num_kv_heads, hidden_size]
+                    module.self_attn.v_proj.weight.tensor = tensor.data.reshape(v_shape)
+                elif bool(re.match(r"^blk\.\d+\.attn_v\.bias$", name)):
+                    # blk.layer_id.attn_v.bias
+                    module.self_attn.v_proj.add_bias()
+                    module.self_attn.v_proj.bias.tensor = tensor.data
+                elif bool(re.match(r"^blk\.\d+\.attn_output\.weight$", name)):
+                    # blk.layer_id.attn_output.weight
+                    o_shape = [hidden_size, head_size * num_attn_heads]
+                    module.self_attn.o_proj.weight.tensor = tensor.data.reshape(o_shape)
+                elif bool(re.match(r"^blk\.\d+\.attn_output\.bias$", name)):
+                    # blk.layer_id.attn_output.bias
+                    module.self_attn.o_proj.add_bias()
+                    module.self_attn.o_proj.bias.tensor = tensor.data
+                elif bool(re.match(r"^blk\.\d+\.ffn_norm\.weight$", name)):
+                    # blk.layer_id.ffn_norm.weight
+                    module.post_attention_layernorm.weight.tensor = tensor.data
+                elif bool(re.match(r"^blk\.\d+\.ffn_norm\.bias$", name)):
+                    # blk.layer_id.ffn_norm.bias
+                    module.post_attention_layernorm.add_bias()
+                    module.post_attention_layernorm.bias.tensor = tensor.data
+                elif bool(re.match(r"^blk\.\d+\.ffn_gate\.weight$", name)):
+                    # blk.layer_id.ffn_gate.weight
+                    gate_shape = [intermediate_size, hidden_size]
+                    module.mlp.gate_proj.weight.tensor = tensor.data.reshape(gate_shape)
+                elif bool(re.match(r"^blk\.\d+\.ffn_gate\.bias$", name)):
+                    # blk.layer_id.ffn_gate.bias
+                    module.mlp.gate_proj.add_bias()
+                    module.mlp.gate_proj.bias.tensor = tensor.data
+                elif bool(re.match(r"^blk\.\d+\.ffn_up\.weight$", name)):
+                    # blk.layer_id.ffn_up.weight
+                    up_shape = [intermediate_size, hidden_size]
+                    module.mlp.up_proj.weight.tensor = tensor.data.reshape(up_shape)
+                elif bool(re.match(r"^blk\.\d+\.ffn_up\.bias$", name)):
+                    # blk.layer_id.ffn_up.bias
+                    module.mlp.up_proj.add_bias()
+                    module.mlp.up_proj.bias.tensor = tensor.data
+                elif bool(re.match(r"^blk\.\d+\.ffn_down\.weight$", name)):
+                    # blk.layer_id.ffn_down.weight
+                    down_shape = [hidden_size, intermediate_size]
+                    module.mlp.down_proj.weight.tensor = tensor.data.reshape(down_shape)
+                elif bool(re.match(r"^blk\.\d+\.ffn_down\.bias$", name)):
+                    # blk.layer_id.ffn_down.bias
+                    module.mlp.down_proj.add_bias()
+                    module.mlp.down_proj.bias.tensor = tensor.data
+                else:
+                    raise NotImplementedError(f"{name} in your GGUF model is not recognized")
+
+        # Append final layer to list of layers
+        self.layers.append(module)
+
+        # Set LM head weights + biases if not already set
+        if self.lm_head.weight.tensor is None:
+            # Embedding and LM head share same weights + biases (lm_head.weight == embedding.weight and lm_head.bias == embedding.bias)
+            self.lm_head.weight.tensor = self.embedding.weight.tensor
+            if self.lm_head.bias is not None:
+                self.lm_head.bias.tensor = self.embedding.bias.tensor
+
+        # Sort list of layers by layer id
+        self.layers.sort(key=lambda m: m.layer_id)
+
+    def modules(self):
+        """
+        Return list of modules in GGUF model in order of appearance in the model
+        """
+        return [self.embedding] + self.layers + [self.final_norm, self.lm_head]
+
+    def undo_permute(self, head_size, hidden_size, num_attn_heads, num_kv_heads):
+        """
+        Undo `permute` operation by GGUF to get Hugging Face format
+        For GGUF models produced by `convert.py` (e.g. LLaMA, Mistral)
+        """
+        for module in self.layers:
+            q_shape = [head_size * num_attn_heads, hidden_size]
+            module.self_attn.q_proj.weight.tensor = module.self_attn.q_proj.weight.tensor.flatten().reshape(num_attn_heads, q_shape[0] // num_attn_heads // 2, 2, *q_shape[1:]).swapaxes(1, 2).reshape(q_shape)

+            k_shape = [head_size * num_kv_heads, hidden_size]
+            module.self_attn.k_proj.weight.tensor = module.self_attn.k_proj.weight.tensor.flatten().reshape(num_kv_heads, k_shape[0] // num_kv_heads // 2, 2, *k_shape[1:]).swapaxes(1, 2).reshape(k_shape)
+
+    def swap_mlp_types(self):
+        """
+        Switch from using the default `up_proj`/`down_proj` attributes to the `fc1`/`fc2` attributes respectively
+        For GGUF models such as Phi-2
+        """
+        # Convert ffn_up (up_proj in Hugging Face model) to fc1
+        # Convert ffn_down (down_proj in Hugging Face model) to fc2
+        for module in self.layers:
+            module.mlp.fc1, module.mlp.up_proj = module.mlp.up_proj, module.mlp.fc1
+            module.mlp.fc2, module.mlp.down_proj = module.mlp.down_proj, module.mlp.fc2
+
+    @staticmethod
+    def from_pretrained(model_type, input_path, head_size, hidden_size, intermediate_size, num_attn_heads, num_kv_heads, vocab_size):
+        """
+        Create GGUF models with the same attribute structures as Hugging Face's PyTorch models.
+        Also performs any pre-processing and post-processing to the GGUF models to ensure the
+        weights are the same as the PyTorch models.
+        """
+        if model_type == "GemmaForCausalLM":
+            # convert-hf-to-gguf.py
+            model = GGUFModel(input_path, head_size, hidden_size, intermediate_size, num_attn_heads, num_kv_heads, vocab_size)
+        elif model_type == "LlamaForCausalLM":
+            # convert.py
+            model = GGUFModel(input_path, head_size, hidden_size, intermediate_size, num_attn_heads, num_kv_heads, vocab_size)
+            model.undo_permute(head_size, hidden_size, num_attn_heads, num_kv_heads)
+        elif model_type == "MistralForCausalLM":
+            # convert.py
+            model = GGUFModel(input_path, head_size, hidden_size, intermediate_size, num_attn_heads, num_kv_heads, vocab_size)
+            model.undo_permute(head_size, hidden_size, num_attn_heads, num_kv_heads)
+        elif model_type == "PhiForCausalLM":
+            # convert-hf-to-gguf.py
+            model = GGUFModel(input_path, head_size, hidden_size, intermediate_size, num_attn_heads, num_kv_heads, vocab_size)
+            model.swap_mlp_types()
+        else:
+            raise NotImplementedError(f"The {model_type} model is not currently supported.")
+
+        return model
\ No newline at end of file
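
A minimal usage sketch of the new `create_model()` signature introduced above, assuming `builder.py` is importable from the working directory and a float16 Mistral GGUF file is already on disk; the model name, file paths, and the `cuda` execution provider below are illustrative placeholders, not values taken from this change.

```python
# Hypothetical example: build an ONNX model from a float16 GGUF file.
# The GGUF file supplies the weights (-i/--input), while the Hugging Face
# model name (-m/--model_name) is only used to load the config and copy the
# tokenizer/processing files into the output folder.
from builder import create_model

create_model(
    "mistralai/Mistral-7B-v0.1",  # -m/--model_name (placeholder)
    "./mistral-7b-f16.gguf",      # -i/--input: path to the GGUF file on disk (placeholder)
    "./mistral-7b-onnx",          # -o/--output: folder for the generated ONNX model
    "fp16",                       # -p/--precision
    "cuda",                       # -e/--execution_provider (placeholder)
    "./cache_dir",                # -c/--cache_dir: Hugging Face files and temporary external data
)
```

This mirrors the documented CLI form `python3 builder.py -m model_name -i path_to_gguf_file -o path_to_output_folder -p precision -e execution_provider -c cache_dir_for_hf_files`.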