41 changes: 41 additions & 0 deletions src/python/py/models/README.md
@@ -19,6 +19,7 @@ This folder contains the model builder for quickly creating optimized and quanti
- [Exclude Embedding Layer](#exclude-embedding-layer)
- [Exclude Language Modeling Head](#exclude-language-modeling-head)
- [Include Last Hidden States Output](#include-last-hidden-states-output)
- [Enable Shared Embeddings](#enable-shared-embeddings)
- [Enable CUDA Graph](#enable-cuda-graph)
- [Use 8 Bits Quantization in QMoE](#use-8-bits-quantization-in-qmoe)
- [Use QDQ Pattern for Quantization](#use-qdq-pattern-for-quantization)
@@ -212,6 +213,46 @@ python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p p

Note that this is the same as outputting embeddings since the last hidden states are also known as the embeddings.

#### Enable Shared Embeddings

This scenario is for when you want to share weights between the embedding layer and the language modeling head. Sharing reduces model size and can improve memory efficiency, and it is especially useful for models with tied embeddings (where `tie_word_embeddings=true` in config.json). In that case, shared embeddings are enabled automatically (you can override this with `shared_embeddings=false`). Shared embeddings cannot be combined with `exclude_embeds=true` or `exclude_lm_head=true`.
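
If you are not sure whether a model ties its embeddings, you can check its config before exporting. The following is a minimal sketch (it assumes the `transformers` package is installed and that `model_name` is a valid Hugging Face model ID):

```
from transformers import AutoConfig

# Only config.json is fetched; no model weights are downloaded.
config = AutoConfig.from_pretrained("model_name")

# True means the embedding layer and LM head already share one weight matrix,
# so the builder enables shared embeddings automatically.
print(getattr(config, "tie_word_embeddings", False))
```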

##### Option 1: INT4 (for RTN and K-Quant)
```
# From wheel:
python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true int4_algo_config=k_quant

# From source:
python3 builder.py -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true int4_algo_config=k_quant
```

##### Option 2: INT4 + INT8 embeddings (for RTN Last and K-Quant Last)
```
# From wheel:
python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true int4_algo_config=k_quant_last

# From source:
python3 builder.py -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true int4_algo_config=k_quant_last
```

##### Option 3: INT4 + FP16 embeddings (for RTN with the LM head excluded from quantization)
```
# From wheel:
python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true int4_algo_config=rtn int4_nodes_to_exclude=/lm_head/MatMul

# From source:
python3 builder.py -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true int4_algo_config=rtn int4_nodes_to_exclude=/lm_head/MatMul
```

##### Option 4: FP16 embeddings
```
# From wheel:
python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p fp16 -e cuda --extra_options shared_embeddings=true

# From source:
python3 builder.py -m model_name -o path_to_output_folder -p fp16 -e cuda --extra_options shared_embeddings=true
```
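
To sanity-check that the exported model actually shares the weights, you can inspect the graph with the `onnx` package. This is a rough sketch (the node and initializer names follow the builder's conventions described above and may change between versions):

```
import onnx

# Load the graph structure only; external weight data is not needed for this check.
model = onnx.load("path_to_output_folder/model.onnx", load_external_data=False)

# With shared embeddings there is no separate embedding weight initializer;
# the embedding lookup reuses the LM head weight instead.
initializer_names = {init.name for init in model.graph.initializer}
print("model.embed_tokens.weight" in initializer_names)  # expect False when sharing is enabled

# The lookup appears as Gather (float weights) or GatherBlockQuantized (quantized weights).
for node in model.graph.node:
    if node.op_type in ("Gather", "GatherBlockQuantized") and node.name.startswith("/model/embed_tokens"):
        print(node.name, node.op_type)
```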

#### Enable CUDA Graph

This scenario is for when you want to enable CUDA graph for your ONNX model.
14 changes: 10 additions & 4 deletions src/python/py/models/builder.py
@@ -59,7 +59,7 @@ def check_extra_options(kv_pairs, execution_provider):
"use_qdq",
"use_webgpu_fp32",
"use_cuda_bf16",
"int4_tied_embeddings",
"shared_embeddings",
"hf_remote",
]
for key in bools:
@@ -390,11 +390,17 @@ def get_args():
Use this option when you want to exclude certain nodes from being quantized.
Separate the node names with a ',' when passing them here (e.g. int4_nodes_to_exclude=/lm_head/MatMul,/model/embed_tokens/Gather)
int4_algo_config = Method for int4 quantization. Default is 'default'.
Currently supported options are: 'default', 'rtn', 'k_quant_mixed', 'k_quant_last'.
Currently supported options are: 'default', 'rtn', 'rtn_last', 'k_quant', 'k_quant_mixed', 'k_quant_last'.
default = algo_config passed to MatMulNBitsQuantizer is None. The quantizer uses its default RTN algorithm and quantizes all MatMuls as int4 (note: node naming conventions differ from `rtn`).
rtn = RTN algorithm for int4 quantization.
rtn_last = RTN algorithm where only the last MatMul (/lm_head/MatMul) is quantized as int8. Other MatMuls are quantized as int4.
k_quant = k_quant algorithm for int4 quantization.
k_quant_mixed = k_quant algorithm with mixed precision (int4 + int8).
k_quant_last = k_quant algorithm where only the last MatMul (/lm_head/MatMul) is quantized as int8. Other MatMuls are quantized as int4.
int4_tied_embeddings = Enable weight sharing for quantization. Default is false.
Use this option when you want to share the weights in the embedding and unembedding.
shared_embeddings = Enable weight sharing between the embedding and LM head layers. Defaults to the model's `tie_word_embeddings` setting in config.json (false if not set).
Use this option to reduce model size by eliminating the duplicate embedding/LM head weights.
For quantized models (INT4/UINT4): Shares quantized weights using GatherBlockQuantized. Only works with rtn and k_quant algorithms, and cannot be used if LM head is excluded.
For float models (FP16/FP32/BF16): Shares float weights using Gather. Works for pure FP models or INT4 models where LM head is excluded from quantization.
num_hidden_layers = Manually specify the number of layers in your ONNX model.
Used for unit testing purposes.
filename = Filename for ONNX model (default is 'model.onnx').
65 changes: 51 additions & 14 deletions src/python/py/models/builders/base.py
@@ -17,6 +17,7 @@
import torch
from onnx_ir.tensor_adapters import TorchTensor, to_torch_dtype
from onnxruntime.quantization.matmul_nbits_quantizer import (
KQuantWeightOnlyQuantConfig,
MatMulNBitsQuantizer,
QuantFormat,
RTNWeightOnlyQuantConfig,
@@ -377,16 +378,31 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
config.quantization_config["desc_act"] if "desc_act" in config.quantization_config else False
)

self.int4_tied_embeddings = (
# Determine if the lm_head is unquantized. For int4/int8 models it can be excluded via int4_nodes_to_exclude; FP models are always unquantized.
self.unquantized_lm_head = "/lm_head/MatMul" in self.quant_attrs["int4"][
"nodes_to_exclude"
] or self.onnx_dtype in {ir.DataType.FLOAT, ir.DataType.FLOAT16, ir.DataType.BFLOAT16}
self.shared_embeddings = extra_options.get(
"shared_embeddings",
config.tie_word_embeddings
if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None
else False
else False,
)
self.int4_tied_embeddings = extra_options.get("int4_tied_embeddings", self.int4_tied_embeddings)
self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last"}
if not self.int8_lm_head:
self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {
"k_quant_mixed",
"k_quant_last",
"rtn_last",
}

# shared_embeddings conflicts with exclude_embeds and exclude_lm_head
if self.shared_embeddings and (self.exclude_embeds or self.exclude_lm_head):
self.shared_embeddings = False
elif self.shared_embeddings and not self.unquantized_lm_head:
# matmul_nbits_quantizer.py uses different naming for default quantization, so lm_head.MatMul.weight_Q{}G{} does not match.
self.int4_tied_embeddings = False
self.shared_embeddings = self.int8_lm_head or extra_options.get("int4_algo_config", "default") in {
"rtn",
"k_quant",
}

def to_str_dtype(self, dtype: ir.DataType) -> str:
return dtype.name
@@ -629,11 +645,14 @@ def make_int4_algo_config(self, quant_method: str):
customized_weight_config = {}
int4_algo_config = None

if quant_method == "rtn":
int4_algo_config = RTNWeightOnlyQuantConfig()
if quant_method in {"rtn", "rtn_last"}:
if quant_method == "rtn_last":
customized_weight_config["/lm_head/MatMul"] = {"bits": 8}
int4_algo_config = RTNWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)

elif quant_method in {"k_quant_mixed", "k_quant_last"}:
from onnxruntime.quantization.matmul_nbits_quantizer import KQuantWeightOnlyQuantConfig
elif quant_method in {"k_quant", "k_quant_mixed", "k_quant_last"}:
if quant_method != "k_quant":
customized_weight_config["/lm_head/MatMul"] = {"bits": 8}

if quant_method == "k_quant_mixed":
# k_quant_mixed is from llama.cpp.
@@ -1262,23 +1281,24 @@ def make_packed_add(self, q_add, k_add, v_add, name, root_input, **kwargs):

def make_embedding(self, embedding):
basename = "/model/embed_tokens"
if self.int4_tied_embeddings:
# Use GatherBlockQuantized if and only if shared embeddings are enabled and the exported model is quantized (the onnx_dtype set in set_onnx_dtype is INT4/UINT4).
if self.shared_embeddings and self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4}:
gather_name = f"{basename}/GatherBlockQuantized"
gather_output = f"{gather_name}/output_0"

weight_reshape_name = f"{basename}/Reshape"
bits = 8 if self.int8_lm_head else 4
flat_dim = self.hidden_size * bits // 8
weight_reshape_inputs = [
f"lm_head.MatMul.weight_Q{bits}G{self.int4_block_size}",
f"/model/constants/INT64/[{self.vocab_size}, {self.hidden_size}]",
f"/model/constants/INT64/[{self.vocab_size}, {flat_dim}]",
]
weight_reshape_output = f"{weight_reshape_name}/output_0"
# quantized weight dtype is uint8, see here
# https://github.com/microsoft/onnxruntime/blob/0c9356cb986fd4cd2c5d510909d31186010ba226/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py#L73
self.make_reshape(
weight_reshape_name, weight_reshape_inputs, dtype=ir.DataType.UINT8, shape=["vocab_size", "hidden_size"]
weight_reshape_name, weight_reshape_inputs, dtype=ir.DataType.UINT8, shape=[self.vocab_size, flat_dim]
)

self.make_node(
"GatherBlockQuantized",
inputs=[weight_reshape_output, "input_ids", "lm_head.MatMul.weight_scale", "lm_head.MatMul.weight_zp"],
@@ -1287,7 +1307,24 @@ def make_embedding(self, embedding):
domain="com.microsoft",
bits=bits,
block_size=int(self.int4_block_size),
gather_axis=0,
quantize_axis=1,
)
# Use Transpose + Gather for shared embeddings when the embedding/LM head weights are float
elif self.shared_embeddings and self.unquantized_lm_head:
transpose_name = f"{basename}/Transpose"
transpose_output = f"{transpose_name}/output_0"
self.make_transpose(
transpose_name,
"lm_head.MatMul.weight",
self.io_dtype,
shape=[self.vocab_size, self.hidden_size],
perm=[1, 0],
)

gather_name = f"{basename}/Gather"
gather_output = f"{gather_name}/output_0"
self.make_node("Gather", inputs=[transpose_output, "input_ids"], outputs=[gather_output], name=gather_name)
else:
weight = "model.embed_tokens.weight"
self.make_initializer(embedding, weight, to=self.io_dtype)