Add GGUF to model builder tool (#138)
### Description
This PR adds support for converting float16/float32 GGUF models to
optimized and quantized ONNX models via the model builder tool.
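
For example, once this change is in place, a GGUF file can be passed to the builder through the new `-i/--input` option (the model name, file paths, precision, and execution provider below are illustrative placeholders):

```
python3 -m onnxruntime_genai.models.builder -m mistralai/Mistral-7B-v0.1 -i ./mistral-7b-f16.gguf -o ./mistral-7b-onnx -p int4 -e cpu -c ./cache_dir
```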

### Motivation and Context
[GGUF](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md) is a
popular file format used in the
[`llama.cpp`](https://github.com/ggerganov/llama.cpp) project. The
project has multiple scripts to convert models to GGUF
([`convert.py`](https://github.com/ggerganov/llama.cpp/blob/master/convert.py),
[`convert-hf-to-gguf.py`](https://github.com/ggerganov/llama.cpp/blob/master/convert-hf-to-gguf.py),
[`convert-llama-ggml-to-gguf.py`](https://github.com/ggerganov/llama.cpp/blob/master/convert-llama-ggml-to-gguf.py),
etc).

Each conversion script applies only to specific model architectures. For
the architectures currently supported in the model builder tool, the
corresponding conversion scripts are:
- LLaMA: `convert.py`
- Mistral: `convert.py`
- Phi-2: `convert-hf-to-gguf.py`
- Gemma: `convert-hf-to-gguf.py`

Depending on the conversion script, the weights are also stored
differently.
- `convert.py`
[permutes](https://github.com/ggerganov/llama.cpp/blob/d5ab29757ebc59a30f03e408294ec20628a6374e/convert.py#L565)
the [Q projection and K projection
weights](https://github.com/ggerganov/llama.cpp/blob/d5ab29757ebc59a30f03e408294ec20628a6374e/convert.py#L1186-L1187)
before storing them (a rough sketch of this permutation follows this list)
- `convert-hf-to-gguf.py` stores the weights in their [original
order](https://github.com/ggerganov/llama.cpp/blob/c29af7e2252d288f2ea58a7d437c1cb7c0abf160/gguf-py/gguf/gguf_writer.py#L244)
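
For reference, the permutation applied by `convert.py` looks roughly like the following simplified sketch (based on the linked `permute` helper; the exact signature and call sites in `llama.cpp` may differ):

```python
import numpy as np

def permute(weights: np.ndarray, n_head: int) -> np.ndarray:
    # Reorder the rows of each attention head so the stored Q/K weights match
    # the rotary-embedding layout that llama.cpp expects, rather than the
    # layout used by the Hugging Face checkpoints.
    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
                   .swapaxes(1, 2)
                   .reshape(weights.shape))

# A GGUF file produced by convert.py stores permute(W_q) and permute(W_k), so a
# loader such as the model builder must undo this permutation to recover the
# original Hugging Face weight order.
```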

New model architectures added to the project now appear to use
`convert-hf-to-gguf.py` for conversion.

### Notes About Gemma Models

There are two ways to obtain GGUF versions of Gemma: 1) download the
PyTorch model from Hugging Face and convert it with
`convert-hf-to-gguf.py`, or 2) download Google's released GGUF versions
from Hugging Face.

#### Converting Gemma from Hugging Face to GGUF

For the Gemma GGUF models created through this conversion, a parity
mismatch was discovered in the LayerNorm weights when comparing the
converted GGUF models and the PyTorch models on Hugging Face. For more
details on the error and the fix for the parity mismatch, please refer
to [this PR](ggerganov/llama.cpp#5810) in the `llama.cpp` project.

Users should re-run `convert-hf-to-gguf.py` to obtain the correct
LayerNorm weights in their Gemma GGUF models.
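
For reference, regenerating the GGUF file looks roughly like this (flags reflect `llama.cpp` at the time of writing and may change; the model path and output name are placeholders):

```
python3 convert-hf-to-gguf.py /path/to/gemma-2b --outtype f16 --outfile gemma-2b-f16.gguf
```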

#### Released GGUF Versions of Gemma
The Gemma GGUF models released on Hugging Face have a vocab size of
256128, which matches the vocab size specified in the [official
paper](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf).
However, the Gemma PyTorch models released on Hugging Face have a [vocab
size of
256000](https://huggingface.co/google/gemma-2b/blob/9d067f00def958594aaa16b39a65b07d69ca655b/config.json#L26).

This difference affects the size of the embeddings. Upon further
examination, the embeddings in the released GGUF models are padded to
the larger vocab size. Once the padding is removed, the embeddings in
the released GGUF models and the released PyTorch models have the same
size and have parity.
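
One way to check this is to strip the padding rows from the GGUF embedding and compare against the Hugging Face embedding. The sketch below is only an illustration (not part of this PR): it assumes the `gguf` Python package's `GGUFReader`, the standard `token_embd.weight` tensor name, a float16/float32 GGUF file, and that the padding rows sit at the end of the vocabulary.

```python
import numpy as np
from gguf import GGUFReader
from transformers import AutoModelForCausalLM

# Released PyTorch weights (vocab size 256000).
hf_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
hf_embed = hf_model.model.embed_tokens.weight.detach().numpy()

# Released GGUF weights (vocab size 256128); the file name is a placeholder.
reader = GGUFReader("gemma-2b.gguf")
gguf_embed = next(np.array(t.data) for t in reader.tensors if t.name == "token_embd.weight")
gguf_embed = gguf_embed.reshape(-1, hf_embed.shape[1])

# Drop the assumed 128 padding rows at the end, then compare.
print(np.allclose(gguf_embed[: hf_embed.shape[0]], hf_embed, atol=1e-3))
```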

It is possible that the released GGUF models were converted from
internal checkpoints instead of the released PyTorch checkpoints. This
could explain why the embeddings have different sizes and why there are
still some parity mismatches in other weights between the released GGUF
models and the released PyTorch models.
kunal-vaishnavi authored Mar 1, 2024
1 parent 995dd7c commit adec01f
Showing 4 changed files with 351 additions and 42 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -14,6 +14,7 @@ cache_dir
example-models
*.onnx
*.onnx.data
__pycache__
benchmark/python/*.csv

!test/test_models/hf-internal-testing/
53 changes: 38 additions & 15 deletions src/python/py/models/README.md
@@ -2,6 +2,19 @@

This folder contains the model builder tool, which greatly accelerates creating optimized and quantized ONNX models that run with ONNX Runtime GenAI.

# Contents
- [Current Support](#current-support)
- [Usage](#usage)
- [Full Usage](#full-usage)
- [Original PyTorch Model from Hugging Face](#original-pytorch-model-from-hugging-face)
- [Original PyTorch Model from Disk](#original-pytorch-model-from-disk)
- [Customized or Finetuned PyTorch Model](#customized-or-finetuned-pytorch-model)
- [GGUF Model](#gguf-model)
- [Extra Options](#extra-options)
- [Unit Testing Models](#unit-testing-models)
- [Option 1: Use the model builder tool directly](#option-1-use-the-model-builder-tool-directly)
- [Option 2: Edit the config.json file](#option-2-edit-the-configjson-file-on-disk-and-then-run-the-model-builder-tool)

## Current Support
The tool currently supports the following model architectures.

@@ -22,44 +35,54 @@ python3 -m onnxruntime_genai.models.builder --help
python3 builder.py --help
```

### Original Model From Hugging Face
### Original PyTorch Model from Hugging Face
This scenario is where your PyTorch model is not downloaded locally (either in the default Hugging Face cache directory or in a local folder on disk).
```
# From wheel:
python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files
python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_save_hf_files
# From source:
python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files
python3 builder.py -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_save_hf_files
```

### Original Model From Disk
### Original PyTorch Model from Disk
This scenario is where your PyTorch model is already downloaded locally (either in the default Hugging Face cache directory or in a local folder on disk).
```
# From wheel:
python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
# From source:
python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
python3 builder.py -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
```

### Customized or Finetuned Model
### Customized or Finetuned PyTorch Model
This scenario is where your PyTorch model has been customized or finetuned for one of the currently supported model architectures and your model can be loaded in Hugging Face.
```
# From wheel:
python3 -m onnxruntime_genai.models.builder -m path_to_local_folder_on_disk -o /path/to/output/folder -p precision -e execution_provider
python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider
# From source:
python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider
```

### GGUF Model
This scenario is where your float16/float32 GGUF model is already on disk.
```
# From wheel:
python3 -m onnxruntime_genai.models.builder -m model_name -i path_to_gguf_file -o path_to_output_folder -p precision -e execution_provider -c cache_dir_for_hf_files
# From source:
python3 builder.py -m path_to_local_folder_on_disk -o /path/to/output/folder -p precision -e execution_provider
python3 builder.py -m model_name -i path_to_gguf_file -o path_to_output_folder -p precision -e execution_provider -c cache_dir_for_hf_files
```

### Extra Options
This scenario is for when you want to have control over some specific settings. The below example shows how you can pass key-value arguments to `--extra_options`.
```
# From wheel:
python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files --extra_options filename=decoder.onnx
python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_for_hf_files --extra_options filename=decoder.onnx
# From source:
python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files --extra_options filename=decoder.onnx
python3 builder.py -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_for_hf_files --extra_options filename=decoder.onnx
```
To see all available options through `--extra_options`, please use the `help` commands in the `Full Usage` section above.

@@ -83,10 +106,10 @@ tokenizer.save_pretrained(cache_dir)
This option is the simplest but it will download another copy of the PyTorch model onto disk to accommodate the change in the number of hidden layers.
```
# From wheel:
python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider --extra_options num_hidden_layers=4
python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p precision -e execution_provider --extra_options num_hidden_layers=4
# From source:
python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider --extra_options num_hidden_layers=4
python3 builder.py -m model_name -o path_to_output_folder -p precision -e execution_provider --extra_options num_hidden_layers=4
```

#### Option 2: Edit the config.json file on disk and then run the model builder tool
@@ -97,8 +120,8 @@ python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execu

```
# From wheel:
python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
# From source:
python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
python3 builder.py -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
```
86 changes: 59 additions & 27 deletions src/python/py/models/builder.py
@@ -100,11 +100,13 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
}

# RotaryEmbedding-specific variables
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
self.rotemb_attrs = {
"create_rotary_embedding_caches": True, # Create cos/sin caches for rotary embeddings
"num_heads": 0, # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
"rotary_embedding_dim": 0, # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
"theta": config.rope_theta, # Base value if calculating cos/sin caches from scratch
"create_rotary_embedding_caches": True, # Create cos/sin caches for rotary embeddings
"partial_rotary_factor": partial_rotary_factor, # Factor for partial rotary embeddings
"num_heads": 0, # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
"rotary_embedding_dim": 0, # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
"theta": config.rope_theta, # Base value if calculating cos/sin caches from scratch
}

# Attention-specific variables (MHA, GQA, GQA + Rot.Emb., etc.)
@@ -545,7 +547,8 @@ def make_rotary_embedding(self, rotemb, name, root_input, **kwargs):
if self.rotemb_attrs["create_rotary_embedding_caches"]:
if not hasattr(rotemb, "cos_cached"):
# Create cos/sin caches if not already created
inv_freq = 1.0 / (self.rotemb_attrs["theta"] ** (torch.arange(0, self.head_size, 2, dtype=torch.int64).float() / self.head_size))
dim = int(self.rotemb_attrs["partial_rotary_factor"] * self.head_size)
inv_freq = 1.0 / (self.rotemb_attrs["theta"] ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
t = torch.arange(self.context_length, dtype=torch.int64).type_as(inv_freq)
freqs = torch.outer(t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
@@ -789,7 +792,7 @@ def make_layer(self, layer_id, layer):
# Norm after last decoder layer of model (last layer --> norm)
self.layernorm_attrs["last_layernorm"] = True

def make_model(self):
def make_model(self, input_path):
# Make inputs and outputs to ONNX model
self.make_inputs_and_outputs()

@@ -802,14 +805,22 @@ def make_model(self):
# 4D causal attention mask
self.make_attention_mask_reformatting()

# Load PyTorch model
extra_kwargs = {} if os.path.exists(self.model_name_or_path) else {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {"cache_dir": self.cache_dir, "use_auth_token": True}
model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, **extra_kwargs)
# Load weights of original model
if input_path.endswith(".gguf"):
# Load GGUF model
from gguf_model import GGUFModel
model = GGUFModel.from_pretrained(self.model_type, input_path, self.head_size, self.hidden_size, self.intermediate_size, self.num_attn_heads, self.num_kv_heads, self.vocab_size)
self.layernorm_attrs["add_offset"] = 0 # add offset already done for GGUF models
else:
# Load PyTorch model
extra_kwargs = {} if os.path.exists(self.model_name_or_path) else {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {"cache_dir": self.cache_dir, "use_auth_token": True}
model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, **extra_kwargs)

# Loop through PyTorch model and map each nn.Module to ONNX/ORT ops
# Loop through model and map each module to ONNX/ORT ops
self.layer_id = 0
for module in model.modules():
if isinstance(module, torch.nn.Embedding):
if isinstance(module, torch.nn.Embedding) or (hasattr(model, "embedding") and module == model.embedding):
# Checks (Hugging Face logic) or (GGUF logic)
# Embedding layer
print("Reading embedding layer")
self.make_embedding(module.weight.detach().numpy())
Expand All @@ -822,17 +833,21 @@ def make_model(self):
# SkipLayerNorm after last decoder layer (MatMul --> SkipLayerNorm)
print("Reading final norm")
self.make_layernorm(self.layer_id, module, skip=True, simple=self.layernorm_attrs["simple"], location="final_norm")
elif isinstance(module, torch.nn.Linear) and module.out_features == self.vocab_size:
elif (isinstance(module, torch.nn.Linear) and module.out_features == self.vocab_size) or (hasattr(model, "lm_head") and module == model.lm_head):
# Checks (Hugging Face logic) or (GGUF logic)
# Language modeling head (SkipLayerNorm --> logits)
print("Reading LM head")
self.make_lm_head(module)

del model

def has_final_norm(self, module, model):
norm = hasattr(model.model, "norm") and module == model.model.norm
final_layernorm = hasattr(model.model, "final_layernorm") and module == model.model.final_layernorm
return norm or final_layernorm
# Hugging Face names
hf_norm = hasattr(model, "model") and hasattr(model.model, "norm") and module == model.model.norm
hf_final_layernorm = hasattr(model, "model") and hasattr(model.model, "final_layernorm") and module == model.model.final_layernorm
# GGUF names
gguf_final_norm = hasattr(model, "final_norm") and module == model.final_norm
return hf_norm or hf_final_layernorm or gguf_final_norm

def make_attention_mask_reformatting(self):
# Make nodes for the attention mask subgraphs that reformat the
@@ -1362,11 +1377,15 @@ def parse_extra_options(kv_items):
return kv_pairs


def create_model(model_name_or_path, output_dir, precision, execution_provider, cache_dir, **extra_options):
def create_model(model_name, input_path, output_dir, precision, execution_provider, cache_dir, **extra_options):
# Create cache and output directories
os.makedirs(output_dir, exist_ok=True)
os.makedirs(cache_dir, exist_ok=True)
extra_kwargs = {} if os.path.exists(model_name_or_path) else {"cache_dir": cache_dir, "use_auth_token": True}
config = AutoConfig.from_pretrained(model_name_or_path, **extra_kwargs)

# Load model config
extra_kwargs = {} if os.path.isdir(input_path) else {"cache_dir": cache_dir, "use_auth_token": True}
hf_name = input_path if os.path.isdir(input_path) else model_name
config = AutoConfig.from_pretrained(hf_name, **extra_kwargs)

# Set input/output precision of ONNX model
io_dtype = TensorProto.FLOAT if precision in {"int8", "fp32"} or (precision == "int4" and execution_provider == "cpu") else TensorProto.FLOAT16
@@ -1381,29 +1400,42 @@ def create_model(model_name_or_path, output_dir, precision, execution_provider,
elif config.architectures[0] == "PhiForCausalLM":
onnx_model = PhiModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
else:
raise NotImplementedError(f"The {model_name_or_path} model is not currently supported.")
raise NotImplementedError(f"The {hf_name} model is not currently supported.")

# Make ONNX model
onnx_model.make_model()
onnx_model.make_model(input_path)

# Save ONNX model
onnx_model.save_model(output_dir)

# Make GenAI config
onnx_model.make_genai_config(model_name_or_path, extra_kwargs, output_dir)
onnx_model.make_genai_config(hf_name, extra_kwargs, output_dir)

# Copy Hugging Face processing files to output folder
onnx_model.save_processing(model_name_or_path, extra_kwargs, output_dir)
onnx_model.save_processing(hf_name, extra_kwargs, output_dir)


def get_args():
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)

parser.add_argument(
"-m",
"--model_name_or_path",
required=True,
help="Model name in Hugging Face or path to folder on disk containing the Hugging Face config, model, tokenizer, etc.",
"--model_name",
required=False,
default=None,
help="Model name in Hugging Face. Do not use if providing an input path to a Hugging Face directory in -i/--input.",
)

parser.add_argument(
"-i",
"--input",
required=False,
default="",
help=textwrap.dedent("""\
Input model source. Currently supported options are:
hf_path: Path to folder on disk containing the Hugging Face config, model, tokenizer, etc.
gguf_path: Path to float16/float32 GGUF file on disk containing the GGUF model
"""),
)

parser.add_argument(
@@ -1435,7 +1467,7 @@ def get_args():
required=False,
type=str,
default=os.path.join('.', 'cache_dir'),
help="Model cache directory (if providing model name and not folder path)",
help="Cache directory for Hugging Face files and temporary ONNX external data files",
)

parser.add_argument(
@@ -1465,4 +1497,4 @@ def get_args():
if __name__ == '__main__':
args = get_args()
extra_options = parse_extra_options(args.extra_options)
create_model(args.model_name_or_path, args.output, args.precision, args.execution_provider, args.cache_dir, **extra_options)
create_model(args.model_name, args.input, args.output, args.precision, args.execution_provider, args.cache_dir, **extra_options)