Fix Gemma parity issue #5810

Merged: 1 commit merged into ggerganov:master on Mar 1, 2024
Conversation

kunal-vaishnavi
Contributor

Description

This PR fixes a parity issue with Google's Gemma models by moving the addition of the unit offset so that it happens after the dtype conversion.
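As a sketch of the kind of change involved (the helper name and surrounding structure below are illustrative, not the exact diff), the Gemma norm weights receive a unit offset before being written to GGUF; the fix converts the tensor out of torch.bfloat16 first and only then adds the offset:

```python
# Illustrative sketch only, not the exact convert-hf-to-gguf.py diff.
import torch

def prepare_gemma_norm_weight(data_torch: torch.Tensor) -> torch.Tensor:
    # Before the fix, the unit offset was added while the tensor was still in
    # bfloat16, which rounds the result:
    #   data_torch = data_torch + 1
    #   data_torch = data_torch.to(torch.float32)
    # After the fix, the tensor is converted first and offset at full precision:
    data_torch = data_torch.to(torch.float32)
    data_torch = data_torch + 1
    return data_torch
```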

Motivation and Context

The Gemma models from Hugging Face are loaded in torch.bfloat16 precision by default. When a unit add is performed on a torch.bfloat16 tensor, the following behavior occurs.

Example:

```python
>>> import torch
>>> t = torch.tensor(3.141592, dtype=torch.bfloat16)
>>> t
tensor(3.1406, dtype=torch.bfloat16)
>>> t + 1
tensor(4.1250, dtype=torch.bfloat16)
```

The value returned is 4.1250 instead of the expected 4.1406. This happens because bfloat16 carries only 8 bits of significand, so values in [4, 8) are representable only in steps of 0.03125; the exact sum 4.140625 falls halfway between 4.125 and 4.15625 and rounds to 4.125. Because the error is introduced by the addition itself, the result stays 4.1250 even when it is later converted to torch.float16 or torch.float32.

```python
>>> (t + 1).to(torch.float16)
tensor(4.1250, dtype=torch.float16)
>>> (t + 1).to(torch.float32)
tensor(4.1250)
```

If the unit add is performed after the dtype conversion, the value returned is the expected value.

```python
>>> t = torch.tensor(3.141592, dtype=torch.bfloat16)
>>> t
tensor(3.1406, dtype=torch.bfloat16)
>>> t = t.to(torch.float32)
>>> t
tensor(3.1406)
>>> t + 1
tensor(4.1406)
```

When comparing the LayerNorm weights from the Hugging Face Gemma 2B model against the GGUF Gemma 2B model produced by convert-hf-to-gguf.py before this change, the tensor values differ. Each tensor below has 2048 elements and is shown in float16 precision (a sketch of how such a comparison can be made follows the table).

| Tensor Name | Hugging Face | GGUF |
| --- | --- | --- |
| Layer 0 Attention Norm | [0. 3.469 1.469 ... 2.727 3.031 2.984] | [0. 3.469 1.469 ... 2.719 3.031 2.984] |
| Layer 0 FFN Norm | [1.4 1.93 1.723 ... 1.73 1.836 1.73] | [1.398 1.93 1.719 ... 1.734 1.836 1.734] |
| Layer 1 Attention Norm | [1.566 1.383 1.594 ... 1.471 1.354 1.238] | [1.5625 1.383 1.594 ... 1.469 1.352 1.234] |
| Layer 1 FFN Norm | [2.016 1.91 2.148 ... 1.918 2.11 1.73] | [2.016 1.906 2.156 ... 1.922 2.11 1.734] |
| ... | ... | ... |
| Layer 16 Attention Norm | [1.758 2.523 2.148 ... 1.777 2.305 1.945] | [1.758 2.531 2.156 ... 1.781 2.312 1.945] |
| Layer 16 FFN Norm | [2.852 2.758 2.984 ... 2.922 2.977 2.914] | [2.844 2.75 2.984 ... 2.922 2.969 2.906] |
| ... | ... | ... |
| Layer 18 Attention Norm | [1.305 1.832 1.59 ... 1.875 1.516 1.855] | [1.305 1.828 1.594 ... 1.875 1.516 1.859] |
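A rough sketch of how such a comparison can be made; the GGUF tensor name `blk.0.attn_norm.weight`, the file path, and the use of gguf-py's `GGUFReader` are assumptions for illustration, not details taken from this PR:

```python
# Sketch: compare one LayerNorm weight between the HF checkpoint and a converted
# GGUF file. The GGUF tensor name and file path are assumptions for illustration.
import numpy as np
import torch
from gguf import GGUFReader
from transformers import AutoModelForCausalLM

hf = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.bfloat16)
# The GGUF file stores the norm weights with the unit offset already applied.
hf_norm = (hf.model.layers[0].input_layernorm.weight.detach().to(torch.float32) + 1).numpy()

reader = GGUFReader("gemma-2b.gguf")
gg_norm = next(t.data for t in reader.tensors if t.name == "blk.0.attn_norm.weight")

hf_f16 = hf_norm.astype(np.float16)
gg_f16 = np.asarray(gg_norm).astype(np.float16)
print("identical:", np.array_equal(hf_f16, gg_f16))
print("max abs diff:", np.abs(hf_f16.astype(np.float32) - gg_f16.astype(np.float32)).max())
```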

After converting the GGUF model to ONNX and running a parity test with ONNX Runtime, ORT reports a parity mismatch for both prompt processing and token generation.

When comparing the LayerNorm weights from the Hugging Face Gemma 2B model against the GGUF Gemma 2B model produced by convert-hf-to-gguf.py after this change, the tensor values match. Each tensor below has 2048 elements and is shown in float16 precision.

| Tensor Name | Hugging Face | GGUF |
| --- | --- | --- |
| Layer 0 Attention Norm | [0. 3.469 1.469 ... 2.727 3.031 2.984] | [0. 3.469 1.469 ... 2.727 3.031 2.984] |
| Layer 0 FFN Norm | [1.4 1.93 1.723 ... 1.73 1.836 1.73] | [1.4 1.93 1.723 ... 1.73 1.836 1.73] |
| Layer 1 Attention Norm | [1.566 1.383 1.594 ... 1.471 1.354 1.238] | [1.566 1.383 1.594 ... 1.471 1.354 1.238] |
| Layer 1 FFN Norm | [2.016 1.91 2.148 ... 1.918 2.11 1.73] | [2.016 1.91 2.148 ... 1.918 2.11 1.73] |
| ... | ... | ... |
| Layer 16 Attention Norm | [1.758 2.523 2.148 ... 1.777 2.305 1.945] | [1.758 2.523 2.148 ... 1.777 2.305 1.945] |
| Layer 16 FFN Norm | [2.852 2.758 2.984 ... 2.922 2.977 2.914] | [2.852 2.758 2.984 ... 2.922 2.977 2.914] |
| ... | ... | ... |
| Layer 18 Attention Norm | [1.305 1.832 1.59 ... 1.875 1.516 1.855] | [1.305 1.832 1.59 ... 1.875 1.516 1.855] |

After converting the new GGUF model to ONNX and running the same parity test with ONNX Runtime, ORT reports that parity is achieved.
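For reference, a prompt-processing parity check of this kind can be sketched roughly as follows; the ONNX file name and the graph's input/output names (`input_ids`, `attention_mask`, `logits`) are assumptions about the exported model, not details documented in this PR:

```python
# Sketch of a prompt-processing parity check between the HF model and an ONNX export.
# The ONNX file name and the graph's input/output names are assumed for illustration.
import numpy as np
import onnxruntime as ort
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-2b")
hf = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.float32)

inputs = tok("The capital of France is", return_tensors="pt")
with torch.no_grad():
    hf_logits = hf(**inputs).logits.numpy()

sess = ort.InferenceSession("gemma-2b.onnx")
ort_logits = sess.run(
    ["logits"],
    {
        "input_ids": inputs["input_ids"].numpy(),
        "attention_mask": inputs["attention_mask"].numpy(),
    },
)[0]

print("prompt-processing parity:", np.allclose(hf_logits, ort_logits, atol=1e-2))
```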

ggerganov merged commit e743386 into ggerganov:master on Mar 1, 2024
22 of 23 checks passed
kunal-vaishnavi added a commit to microsoft/onnxruntime-genai that referenced this pull request Mar 1, 2024
### Description
This PR adds support for converting float16/float32 GGUF models to
optimized and quantized ONNX models via the model builder tool.

### Motivation and Context
[GGUF](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md) is a
popular file format used in the
[`llama.cpp`](https://github.com/ggerganov/llama.cpp) project. The
project has multiple scripts to convert models to GGUF
([`convert.py`](https://github.com/ggerganov/llama.cpp/blob/master/convert.py),
[`convert-hf-to-gguf.py`](https://github.com/ggerganov/llama.cpp/blob/master/convert-hf-to-gguf.py),
[`convert-llama-ggml-to-gguf.py`](https://github.com/ggerganov/llama.cpp/blob/master/convert-llama-ggml-to-gguf.py),
etc).

Each conversion script applies only to specific model architectures. For
the architectures currently supported in the model builder tool, the
corresponding conversion scripts are:
- LLaMA: `convert.py`
- Mistral: `convert.py`
- Phi-2: `convert-hf-to-gguf.py`
- Gemma: `convert-hf-to-gguf.py`

Depending on the conversion script, the weights are also stored
differently.
- `convert.py`
[permutes](https://github.com/ggerganov/llama.cpp/blob/d5ab29757ebc59a30f03e408294ec20628a6374e/convert.py#L565)
the [Q projection and K projection
weights](https://github.com/ggerganov/llama.cpp/blob/d5ab29757ebc59a30f03e408294ec20628a6374e/convert.py#L1186-L1187)
before storing them (a sketch of this permutation follows the list)
- `convert-hf-to-gguf.py` stores the weights in their [original
order](https://github.com/ggerganov/llama.cpp/blob/c29af7e2252d288f2ea58a7d437c1cb7c0abf160/gguf-py/gguf/gguf_writer.py#L244)
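For reference, the permutation applied by `convert.py` is along the following lines (a paraphrase of the linked code, not the exact source):

```python
# Paraphrase of convert.py's Q/K permutation: regroups the two rotary-embedding
# halves of each attention head before the weights are written to GGUF.
import numpy as np

def permute(weights: np.ndarray, n_head: int, n_head_kv: int) -> np.ndarray:
    if n_head_kv is not None and n_head != n_head_kv:
        n_head = n_head_kv
    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
                   .swapaxes(1, 2)
                   .reshape(weights.shape))
```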

New model architectures added to the project now appear to use
`convert-hf-to-gguf.py` for conversion.

### Notes About Gemma Models

There are two ways to obtain GGUF versions of Gemma: 1) download the
PyTorch model from Hugging Face and use `convert-hf-to-gguf.py` to
convert or 2) download Google's released GGUF versions from Hugging
Face.

#### Converting Gemma from Hugging Face to GGUF

For the Gemma GGUF models created from conversion, a parity mismatch was
discovered in the LayerNorm weights when comparing the converted GGUF
models and the PyTorch models in Hugging Face. For more details on this
error and the fix for the parity mismatch, please refer to [this
PR](ggerganov/llama.cpp#5810) in the `llama.cpp`
project.

Users should run `convert-hf-to-gguf.py` again to obtain the right
LayerNorm weights in the Gemma GGUF models.

#### Released GGUF Versions of Gemma
The Gemma GGUF models released on Hugging Face have a vocab size of
256128, which matches the vocab size specified in the [official
paper](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf).
However, the Gemma PyTorch models released on Hugging Face have a [vocab
size of
256000](https://huggingface.co/google/gemma-2b/blob/9d067f00def958594aaa16b39a65b07d69ca655b/config.json#L26).

This difference affects the size of the embeddings. Upon further
examination, the embeddings in the released GGUF models are padded. When
the padding is removed, the embeddings in both the released GGUF models
and the released PyTorch models have the same size and have parity.
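A rough sketch of that check; the `token_embd.weight` tensor name, the file path, and the use of gguf-py's `GGUFReader` are assumptions for illustration, and an f16/f32 (unquantized) GGUF file is assumed so the tensor data is a plain float array:

```python
# Sketch: trim the padded rows from the released GGUF embeddings and compare
# against the released PyTorch embeddings. The GGUF tensor name and file path
# are assumptions; an f16/f32 (unquantized) GGUF file is assumed.
import numpy as np
import torch
from gguf import GGUFReader
from transformers import AutoModelForCausalLM

hf = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.float32)
hf_embed = hf.model.embed_tokens.weight.detach().numpy()   # (256000, hidden)

reader = GGUFReader("gemma-2b-released.gguf")
gg_tensor = next(t for t in reader.tensors if t.name == "token_embd.weight")
gg_embed = np.asarray(gg_tensor.data).reshape(-1, hf_embed.shape[1])  # (256128, hidden)

trimmed = gg_embed[: hf_embed.shape[0]]                     # drop the padded rows
print("shapes:", hf_embed.shape, gg_embed.shape, trimmed.shape)
print("parity after trimming:", np.allclose(hf_embed, trimmed.astype(np.float32), atol=1e-3))
```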

It is possible that the released GGUF models were converted from
internal checkpoints instead of the released PyTorch checkpoints. This
could explain why the embeddings have different sizes and why there are
still some parity mismatches in other weights between the released GGUF
models and the released PyTorch models.
hazelnutcloud pushed a commit to hazelnutcloud/llama.cpp that referenced this pull request Mar 10, 2024
jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Mar 13, 2024
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024