From 49d26b330a2a0aafeae09f95d5f2c22e24e69e34 Mon Sep 17 00:00:00 2001 From: jixiongdeng Date: Wed, 19 Nov 2025 22:45:11 +0000 Subject: [PATCH 1/7] Migrated from prev-refactor PR: https://github.com/microsoft/onnxruntime-genai/pull/1854 --- src/python/py/models/builder.py | 15 +++++-- src/python/py/models/builders/base.py | 56 ++++++++++++++++++++------- 2 files changed, 53 insertions(+), 18 deletions(-) diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 2d802df43..4e5e4db70 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -391,11 +391,20 @@ def get_args(): Use this option when you want to exclude certain nodes from being quantized. Separate the node names with a ',' when passing them here (e.g. int4_nodes_to_exclude=/lm_head/MatMul,/model/embed_tokens/Gather) int4_algo_config = Method for int4 quantization. Default is 'default'. - Currently supported options are: 'default', 'rtn', 'k_quant_mixed', 'k_quant_last'. + Currently supported options are: 'default', 'rtn', 'rtn_last', 'k_quant', 'k_quant_mixed', 'k_quant_last'. + rtn = RTN algorithm for int4 quantization. + rtn_last = RTN algorithm where only the last MatMul (/lm_head/MatMul) is quantized as int8. Other MatMuls are quantized as int4. + k_quant = k_quant algorithm for int4 quantization. k_quant_mixed = k_quant algorithm with mixed precision (int4 + int8). k_quant_last = k_quant algorithm where only the last MatMul (/lm_head/MatMul) is quantized as int8. Other MatMuls are quantized as int4. - int4_tied_embeddings = Enable weight sharing for quantization. Default is false. - Use this option when you want to share the weights in the embedding and unembedding. + int4_tied_embeddings = Enable weight sharing for quantized models (INT4/UINT4/INT8/UINT8). Default is false. + Use this option when you want to share the quantized weights between embedding and LM head layers. + Only works with rtn and k_quant quantization algorithms. + Cannot be used if LM head is excluded from quantization (use shared_embeddings instead). + shared_embeddings = Enable weight sharing for FP16/FP32/BF16 weights. Default is false. + Use this option when you want to share the float weights between embedding and LM head layers. + Works for pure FP models or INT4 models where LM head is excluded from quantization. + This reduces model size by eliminating duplicate weights. num_hidden_layers = Manually specify the number of layers in your ONNX model. Used for unit testing purposes. filename = Filename for ONNX model (default is 'model.onnx'). 
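As a rough illustration of what the tied/shared-embedding options documented above buy: the embedding table and the LM head weight are both `vocab_size x hidden_size` matrices, so sharing them drops one full copy from the model. A back-of-the-envelope sketch, assuming Llama-2-7B-like dimensions (not values taken from this patch) and ignoring quantization scale/zero-point tensors:

```
# Rough estimate of the duplicated weight matrix that embedding/LM-head sharing removes.
# Dimensions are illustrative (Llama-2-7B-like), not taken from the patch, and the
# quantized figures ignore scale/zero-point tensors.
vocab_size, hidden_size = 32000, 4096
num_weights = vocab_size * hidden_size      # one full [vocab_size, hidden_size] matrix

sizes = {
    "fp16": num_weights * 2,                # 2 bytes per weight
    "int8 lm_head": num_weights,            # 1 byte per weight
    "int4": num_weights // 2,               # 4 bits per weight, two packed per byte
}
for kind, saved_bytes in sizes.items():
    print(f"sharing avoids ~{saved_bytes / 2**20:.0f} MiB of duplicated {kind} weights")
```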
diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index d83326a53..3e99d4369 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -292,13 +292,21 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): # Create quantized attributes from quantization config self.quant_attrs["config"] = config.quantization_config self.quant_attrs["use_g_idx"] = config.quantization_config["desc_act"] if "desc_act" in config.quantization_config else False - - self.int4_tied_embeddings = config.tie_word_embeddings if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None else False - self.int4_tied_embeddings = extra_options.get("int4_tied_embeddings", self.int4_tied_embeddings) - self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last"} - if not self.int8_lm_head: - # matmul_nbits_quantizer.py has a different naming for default quantization, so lm_head.MatMul.weight_Q{}G{} does not match. - self.int4_tied_embeddings = False + + lm_head_excluded = "/lm_head/MatMul" in self.quant_attrs["int4"]["nodes_to_exclude"] + + self.int4_tied_embeddings = extra_options.get("int4_tied_embeddings", config.tie_word_embeddings if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None else False) + self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last", "rtn_last"} + # matmul_nbits_quantizer.py has a different naming for default quantization, so lm_head.MatMul.weight_Q{}G{} does not match. + # tied_embeddings lm_head.MatMul.weight_Q{}G{} only works with rtn&k_quant on 4bit, or with int8 lm_head + self.int4_tied_embeddings = ( + self.int4_tied_embeddings + and not lm_head_excluded + and (self.int8_lm_head or extra_options.get("int4_algo_config", "default") in {"rtn", "k_quant"}) + ) + + # Check if shared embeddings are used for float embeddings and lm_head + self.shared_embeddings = extra_options.get("shared_embeddings", False) def to_str_dtype(self, dtype: ir.DataType) -> str: return dtype.name @@ -385,6 +393,10 @@ def make_attention_init(self): and not self.attention_attrs["k_norm"] ) + # Allow extra_options to override use_packed_matmul + if "unpack_matmul" in self.extra_options: + self.attention_attrs["use_packed_matmul"] = not self.extra_options.get("unpack_matmul", False) + # Some EPs don't support fusing rotary embeddings inside GQA yet self.attention_attrs["use_rope_in_attn"] = self.ep not in ["dml"] if self.attention_attrs["use_rope_in_attn"]: @@ -511,12 +523,17 @@ def make_int4_algo_config(self, quant_method: str): customized_weight_config = {} int4_algo_config = None - if quant_method == "rtn": - int4_algo_config = RTNWeightOnlyQuantConfig() + if quant_method in {"rtn", "rtn_last"}: + if quant_method == "rtn_last": + customized_weight_config["/lm_head/MatMul"] = {"bits": 8} + int4_algo_config = RTNWeightOnlyQuantConfig(customized_weight_config=customized_weight_config) - elif quant_method in {"k_quant_mixed", "k_quant_last"}: + elif quant_method in {"k_quant", "k_quant_mixed", "k_quant_last"}: from onnxruntime.quantization.matmul_nbits_quantizer import KQuantWeightOnlyQuantConfig + if quant_method != "k_quant": + customized_weight_config["/lm_head/MatMul"] = {"bits": 8} + if quant_method == "k_quant_mixed": # k_quant_mixed is from llama.cpp. 
# Reference: https://github.com/ggml-org/llama.cpp/blob/36667c8edcded08063ed51c7d57e9e086bbfc903/src/llama-quant.cpp#L136 @@ -1114,19 +1131,28 @@ def make_packed_add(self, q_add, k_add, v_add, name, root_input, **kwargs): def make_embedding(self, embedding): basename = "/model/embed_tokens" - if self.int4_tied_embeddings: + # Use GatherBlockQuantized if and only if tied embeddings are enabled and export model is quantized. quantized d_type in set_onnx_dtype is INT4/UINT4 + if self.int4_tied_embeddings and self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4}: gather_name = f"{basename}/GatherBlockQuantized" gather_output = f"{gather_name}/output_0" weight_reshape_name = f"{basename}/Reshape" bits = 8 if self.int8_lm_head else 4 - weight_reshape_inputs = [f"lm_head.MatMul.weight_Q{bits}G{self.int4_block_size}", f"/model/constants/INT64/[{self.vocab_size}, {self.hidden_size}]"] + flat_dim = self.hidden_size * bits // 8 + weight_reshape_inputs = [f"lm_head.MatMul.weight_Q{bits}G{self.int4_block_size}", f"/model/constants/INT64/[{self.vocab_size}, {flat_dim}]"] weight_reshape_output = f"{weight_reshape_name}/output_0" # quantized weight dtype is uint8, see here - # https://github.com/microsoft/onnxruntime/blob/0c9356cb986fd4cd2c5d510909d31186010ba226/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py#L73 - self.make_reshape(weight_reshape_name, weight_reshape_inputs, dtype=ir.DataType.UINT8, shape=['vocab_size', 'hidden_size']) + # https://github.com/microsoft/onnxruntime/blob/0c9356cb986fd4cd2c5d510909d31186010ba226/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py#L73 + self.make_reshape(weight_reshape_name, weight_reshape_inputs, dtype=ir.DataType.UINT8, shape=[self.vocab_size, flat_dim]) + self.make_node('GatherBlockQuantized', inputs=[weight_reshape_output, 'input_ids', 'lm_head.MatMul.weight_scale', 'lm_head.MatMul.weight_zp'], outputs=[gather_output], name=gather_name, domain="com.microsoft", bits=bits, block_size=int(self.int4_block_size), gather_axis=0, quantize_axis=1) + elif self.shared_embeddings: + transpose_name = f"{basename}/Transpose" + transpose_output = f"{transpose_name}/output_0" + self.make_transpose(transpose_name, "lm_head.MatMul.weight", self.io_dtype, [self.vocab_size, self.hidden_size], [1, 0]) - self.make_node('GatherBlockQuantized', inputs=[weight_reshape_output, 'input_ids', 'lm_head.MatMul.weight_scale', 'lm_head.MatMul.weight_zp'], outputs=[gather_output], name=gather_name, domain="com.microsoft", bits=bits, block_size=int(self.int4_block_size)) + gather_name = f"{basename}/Gather" + gather_output = f"{gather_name}/output_0" + self.make_node('Gather', inputs=[transpose_output, 'input_ids'], outputs=[gather_output], name=gather_name) else: weight = "model.embed_tokens.weight" self.make_initializer(embedding, weight, to=self.io_dtype) From 5a31d03aa314edd296a737f9fa1ff6a12e138890 Mon Sep 17 00:00:00 2001 From: jixiongdeng Date: Thu, 20 Nov 2025 02:12:04 +0000 Subject: [PATCH 2/7] Merged to . 
--- src/python/py/models/builder.py | 14 +++++--------- src/python/py/models/builders/base.py | 27 ++++++++++++++------------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 4e5e4db70..f3b01a69e 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -60,7 +60,7 @@ def check_extra_options(kv_pairs, execution_provider): "use_qdq", "use_webgpu_fp32", "use_cuda_bf16", - "int4_tied_embeddings", + "shared_embeddings", "hf_remote", ] for key in bools: @@ -397,14 +397,10 @@ def get_args(): k_quant = k_quant algorithm for int4 quantization. k_quant_mixed = k_quant algorithm with mixed precision (int4 + int8). k_quant_last = k_quant algorithm where only the last MatMul (/lm_head/MatMul) is quantized as int8. Other MatMuls are quantized as int4. - int4_tied_embeddings = Enable weight sharing for quantized models (INT4/UINT4/INT8/UINT8). Default is false. - Use this option when you want to share the quantized weights between embedding and LM head layers. - Only works with rtn and k_quant quantization algorithms. - Cannot be used if LM head is excluded from quantization (use shared_embeddings instead). - shared_embeddings = Enable weight sharing for FP16/FP32/BF16 weights. Default is false. - Use this option when you want to share the float weights between embedding and LM head layers. - Works for pure FP models or INT4 models where LM head is excluded from quantization. - This reduces model size by eliminating duplicate weights. + shared_embeddings = Enable weight sharing between embedding and LM head layers. Default is false. + Use this option to share weights and reduce model size by eliminating duplicate weights. + For quantized models (INT4/UINT4): Shares quantized weights using GatherBlockQuantized. Only works with rtn and k_quant algorithms, and cannot be used if LM head is excluded. + For float models (FP16/FP32/BF16): Shares float weights using Gather. Works for pure FP models or INT4 models where LM head is excluded from quantization. num_hidden_layers = Manually specify the number of layers in your ONNX model. Used for unit testing purposes. filename = Filename for ONNX model (default is 'model.onnx'). diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index 3e99d4369..c41891340 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -293,20 +293,20 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): self.quant_attrs["config"] = config.quantization_config self.quant_attrs["use_g_idx"] = config.quantization_config["desc_act"] if "desc_act" in config.quantization_config else False - lm_head_excluded = "/lm_head/MatMul" in self.quant_attrs["int4"]["nodes_to_exclude"] + exclude_embeds = extra_options.get("exclude_embeds", False) + exclude_lm_head = extra_options.get("exclude_lm_head", False) - self.int4_tied_embeddings = extra_options.get("int4_tied_embeddings", config.tie_word_embeddings if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None else False) + # Determine if lm_head is unquantized. int4/8 can have options to int4_nodes_to_exclude. FP models are always unquantized. 
+ self.unquantized_lm_head = "/lm_head/MatMul" in self.quant_attrs["int4"]["nodes_to_exclude"] or self.onnx_dtype in {ir.DataType.FLOAT, ir.DataType.FLOAT16, ir.DataType.BFLOAT16} + self.shared_embeddings = extra_options.get("shared_embeddings", config.tie_word_embeddings if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None else False) self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last", "rtn_last"} - # matmul_nbits_quantizer.py has a different naming for default quantization, so lm_head.MatMul.weight_Q{}G{} does not match. - # tied_embeddings lm_head.MatMul.weight_Q{}G{} only works with rtn&k_quant on 4bit, or with int8 lm_head - self.int4_tied_embeddings = ( - self.int4_tied_embeddings - and not lm_head_excluded - and (self.int8_lm_head or extra_options.get("int4_algo_config", "default") in {"rtn", "k_quant"}) - ) - - # Check if shared embeddings are used for float embeddings and lm_head - self.shared_embeddings = extra_options.get("shared_embeddings", False) + + # shared_embeddings conflicts with exclude_embeds and exclude_lm_head + if self.shared_embeddings and (exclude_embeds or exclude_lm_head): + self.shared_embeddings = False + elif self.shared_embeddings and not self.unquantized_lm_head: + # matmul_nbits_quantizer.py has a different naming for default quantization, so lm_head.MatMul.weight_Q{}G{} does not match. + self.shared_embeddings = self.int8_lm_head or extra_options.get("int4_algo_config", "default") in {"rtn", "k_quant"} def to_str_dtype(self, dtype: ir.DataType) -> str: return dtype.name @@ -1132,7 +1132,7 @@ def make_packed_add(self, q_add, k_add, v_add, name, root_input, **kwargs): def make_embedding(self, embedding): basename = "/model/embed_tokens" # Use GatherBlockQuantized if and only if tied embeddings are enabled and export model is quantized. quantized d_type in set_onnx_dtype is INT4/UINT4 - if self.int4_tied_embeddings and self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4}: + if self.shared_embeddings and self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4}: gather_name = f"{basename}/GatherBlockQuantized" gather_output = f"{gather_name}/output_0" @@ -1145,7 +1145,8 @@ def make_embedding(self, embedding): # https://github.com/microsoft/onnxruntime/blob/0c9356cb986fd4cd2c5d510909d31186010ba226/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py#L73 self.make_reshape(weight_reshape_name, weight_reshape_inputs, dtype=ir.DataType.UINT8, shape=[self.vocab_size, flat_dim]) self.make_node('GatherBlockQuantized', inputs=[weight_reshape_output, 'input_ids', 'lm_head.MatMul.weight_scale', 'lm_head.MatMul.weight_zp'], outputs=[gather_output], name=gather_name, domain="com.microsoft", bits=bits, block_size=int(self.int4_block_size), gather_axis=0, quantize_axis=1) - elif self.shared_embeddings: + # Use Transpose + Gather for tied embeddings for float embedding layers + elif self.shared_embeddings and self.unquantized_lm_head: transpose_name = f"{basename}/Transpose" transpose_output = f"{transpose_name}/output_0" self.make_transpose(transpose_name, "lm_head.MatMul.weight", self.io_dtype, [self.vocab_size, self.hidden_size], [1, 0]) From fca3c0c0ba9cf2074b627f67234522f4b62e2879 Mon Sep 17 00:00:00 2001 From: jixiongdeng Date: Fri, 21 Nov 2025 10:59:18 +0000 Subject: [PATCH 3/7] Added comments for default; moved KQuant config up; explicitly defined shape & perm for transpose. 
--- src/python/py/models/builder.py | 1 + src/python/py/models/builders/base.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index f3b01a69e..c142fbfbd 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -392,6 +392,7 @@ def get_args(): Separate the node names with a ',' when passing them here (e.g. int4_nodes_to_exclude=/lm_head/MatMul,/model/embed_tokens/Gather) int4_algo_config = Method for int4 quantization. Default is 'default'. Currently supported options are: 'default', 'rtn', 'rtn_last', 'k_quant', 'k_quant_mixed', 'k_quant_last'. + default = algo_config passed to MatMulNBitsQuantizer is None. The quantizer uses its default RTN algorithm. All MatMuls are quantized as int4. (Uses a different node naming convention than `rtn`.) rtn = RTN algorithm for int4 quantization. rtn_last = RTN algorithm where only the last MatMul (/lm_head/MatMul) is quantized as int8. Other MatMuls are quantized as int4. k_quant = k_quant algorithm for int4 quantization. diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index c41891340..4a3cf137b 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -20,6 +20,7 @@ MatMulNBitsQuantizer, QuantFormat, RTNWeightOnlyQuantConfig, + KQuantWeightOnlyQuantConfig, ) from tqdm import tqdm from transformers import ( @@ -529,7 +530,6 @@ def make_int4_algo_config(self, quant_method: str): int4_algo_config = RTNWeightOnlyQuantConfig(customized_weight_config=customized_weight_config) elif quant_method in {"k_quant", "k_quant_mixed", "k_quant_last"}: - from onnxruntime.quantization.matmul_nbits_quantizer import KQuantWeightOnlyQuantConfig if quant_method != "k_quant": customized_weight_config["/lm_head/MatMul"] = {"bits": 8} @@ -1149,7 +1149,7 @@ def make_embedding(self, embedding): elif self.shared_embeddings and self.unquantized_lm_head: transpose_name = f"{basename}/Transpose" transpose_output = f"{transpose_name}/output_0" - self.make_transpose(transpose_name, "lm_head.MatMul.weight", self.io_dtype, [self.vocab_size, self.hidden_size], [1, 0]) + self.make_transpose(transpose_name, "lm_head.MatMul.weight", self.io_dtype, shape=[self.vocab_size, self.hidden_size], perm=[1, 0]) gather_name = f"{basename}/Gather" gather_output = f"{gather_name}/output_0" From 3aff1d8e4281a67606b5a69630b5882c58e481e6 Mon Sep 17 00:00:00 2001 From: jixiongdeng Date: Fri, 21 Nov 2025 11:02:18 +0000 Subject: [PATCH 4/7] Erase duplicated obj var definitions --- src/python/py/models/builders/base.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index 4a3cf137b..973c942d7 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -293,9 +293,6 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): # Create quantized attributes from quantization config self.quant_attrs["config"] = config.quantization_config self.quant_attrs["use_g_idx"] = config.quantization_config["desc_act"] if "desc_act" in config.quantization_config else False - - exclude_embeds = extra_options.get("exclude_embeds", False) - exclude_lm_head = extra_options.get("exclude_lm_head", False) # Determine if lm_head is unquantized. int4/8 can have options to int4_nodes_to_exclude. 
self.unquantized_lm_head = "/lm_head/MatMul" in self.quant_attrs["int4"]["nodes_to_exclude"] or self.onnx_dtype in {ir.DataType.FLOAT, ir.DataType.FLOAT16, ir.DataType.BFLOAT16} @@ -303,7 +300,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last", "rtn_last"} # shared_embeddings conflicts with exclude_embeds and exclude_lm_head - if self.shared_embeddings and (exclude_embeds or exclude_lm_head): + if self.shared_embeddings and (self.exclude_embeds or self.exclude_lm_head): self.shared_embeddings = False elif self.shared_embeddings and not self.unquantized_lm_head: # matmul_nbits_quantizer.py has a different naming for default quantization, so lm_head.MatMul.weight_Q{}G{} does not match. From f9f78556cf6f32643174b4d8feba5b4a825c7ad0 Mon Sep 17 00:00:00 2001 From: jixiongdeng Date: Sun, 23 Nov 2025 03:51:57 +0000 Subject: [PATCH 5/7] removed in extra_options --- src/python/py/models/builders/base.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index 0eefa9eb8..42a898fb5 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -483,10 +483,6 @@ def make_attention_init(self): and not self.attention_attrs["k_norm"] ) - # Allow extra_options to override use_packed_matmul - if "unpack_matmul" in self.extra_options: - self.attention_attrs["use_packed_matmul"] = not self.extra_options.get("unpack_matmul", False) - # Some EPs don't support fusing rotary embeddings inside GQA yet self.attention_attrs["use_rope_in_attn"] = self.ep not in ["dml"] if self.attention_attrs["use_rope_in_attn"]: From 9d3bc929ee64a7a37302f07d51102e08a0a32028 Mon Sep 17 00:00:00 2001 From: jixiongdeng Date: Sun, 23 Nov 2025 04:17:37 +0000 Subject: [PATCH 6/7] lint --- src/python/py/models/README.md | 40 ++++++++++++++++++ src/python/py/models/builders/base.py | 60 +++++++++++++++++++++------ 2 files changed, 87 insertions(+), 13 deletions(-) diff --git a/src/python/py/models/README.md b/src/python/py/models/README.md index c45b293a0..5c2066d02 100644 --- a/src/python/py/models/README.md +++ b/src/python/py/models/README.md @@ -212,6 +212,46 @@ python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p p Note that this is the same as outputting embeddings since the last hidden states are also known as the embeddings. +#### Enable Shared Embeddings + +This scenario is for when you want to enable weight sharing between the embedding layer and the language modeling head. This reduces model size and can improve memory efficiency, especially useful for models with tied embeddings (where `tie_word_embeddings=true` in config.json). Shared embeddings are automatically enabled if `tie_word_embeddings=true` in the model's config.json (can be overridden with `shared_embeddings=false`), but cannot be used with `exclude_embeds=true` or `exclude_lm_head=true`. In `-p int4` case, works with RTN and K-quant quantization algorithms. 
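Conceptually, the sharing works because the embedding lookup and the LM head use the same `[vocab_size, hidden_size]` matrix, just in different orientations. A minimal numpy sketch of that relationship (illustrative dimensions only, not the builder's actual ONNX graph construction):

```
import numpy as np

vocab_size, hidden_size = 1000, 64               # made-up sizes for illustration
shared_w = np.random.rand(vocab_size, hidden_size).astype(np.float32)

input_ids = np.array([1, 5, 42])
embeddings = shared_w[input_ids]                 # embedding lookup = row gather from the shared matrix

hidden_states = embeddings                       # stand-in for the transformer's final hidden states
logits = hidden_states @ shared_w.T              # LM head = matmul against the same matrix, transposed
print(logits.shape)                              # (3, 1000)
```

In the exported graph the builder reuses the LM head weight directly, which is why the float path shown earlier in this series adds a Transpose ahead of the Gather, while the quantized path goes through GatherBlockQuantized.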
+ +##### option1: int4 +``` +# From wheel: +python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true + +# From source: +python3 builder.py -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true +``` + +##### option2: int4 + int8 embeddings +``` +# From wheel: +python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true int4_algo_config=k_quant_last + +# From source: +python3 builder.py -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true int4_algo_config=k_quant_last +``` + +##### option3: int4 + fp16 embeddings +``` +# From wheel: +python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true int4_nodes_to_exclude=["/lm_head/MatMul"] + +# From source: +python3 builder.py -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true int4_nodes_to_exclude=["/lm_head/MatMul"] +``` + +##### option4: fp16(unquantized) +``` +# From wheel: +python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p fp16 -e cuda --extra_options shared_embeddings=true + +# From source: +python3 builder.py -m model_name -o path_to_output_folder -p fp16 -e cuda --extra_options shared_embeddings=true +``` + #### Enable CUDA Graph This scenario is for when you want to enable CUDA graph for your ONNX model. diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index 42a898fb5..7ded8e093 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -17,10 +17,10 @@ import torch from onnx_ir.tensor_adapters import TorchTensor, to_torch_dtype from onnxruntime.quantization.matmul_nbits_quantizer import ( + KQuantWeightOnlyQuantConfig, MatMulNBitsQuantizer, QuantFormat, RTNWeightOnlyQuantConfig, - KQuantWeightOnlyQuantConfig, ) from tqdm import tqdm from transformers import ( @@ -379,16 +379,30 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): ) # Determine if lm_head is unquantized. int4/8 can have options to int4_nodes_to_exclude. FP models are always unquantized. 
- self.unquantized_lm_head = "/lm_head/MatMul" in self.quant_attrs["int4"]["nodes_to_exclude"] or self.onnx_dtype in {ir.DataType.FLOAT, ir.DataType.FLOAT16, ir.DataType.BFLOAT16} - self.shared_embeddings = extra_options.get("shared_embeddings", config.tie_word_embeddings if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None else False) - self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last", "rtn_last"} - + self.unquantized_lm_head = "/lm_head/MatMul" in self.quant_attrs["int4"][ + "nodes_to_exclude" + ] or self.onnx_dtype in {ir.DataType.FLOAT, ir.DataType.FLOAT16, ir.DataType.BFLOAT16} + self.shared_embeddings = extra_options.get( + "shared_embeddings", + config.tie_word_embeddings + if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None + else False, + ) + self.int8_lm_head = extra_options.get("int4_algo_config", "default") in { + "k_quant_mixed", + "k_quant_last", + "rtn_last", + } + # shared_embeddings conflicts with exclude_embeds and exclude_lm_head if self.shared_embeddings and (self.exclude_embeds or self.exclude_lm_head): self.shared_embeddings = False elif self.shared_embeddings and not self.unquantized_lm_head: # matmul_nbits_quantizer.py has a different naming for default quantization, so lm_head.MatMul.weight_Q{}G{} does not match. - self.shared_embeddings = self.int8_lm_head or extra_options.get("int4_algo_config", "default") in {"rtn", "k_quant"} + self.shared_embeddings = self.int8_lm_head or extra_options.get("int4_algo_config", "default") in { + "rtn", + "k_quant", + } def to_str_dtype(self, dtype: ir.DataType) -> str: return dtype.name @@ -637,7 +651,6 @@ def make_int4_algo_config(self, quant_method: str): int4_algo_config = RTNWeightOnlyQuantConfig(customized_weight_config=customized_weight_config) elif quant_method in {"k_quant", "k_quant_mixed", "k_quant_last"}: - if quant_method != "k_quant": customized_weight_config["/lm_head/MatMul"] = {"bits": 8} @@ -1276,21 +1289,42 @@ def make_embedding(self, embedding): weight_reshape_name = f"{basename}/Reshape" bits = 8 if self.int8_lm_head else 4 flat_dim = self.hidden_size * bits // 8 - weight_reshape_inputs = [f"lm_head.MatMul.weight_Q{bits}G{self.int4_block_size}", f"/model/constants/INT64/[{self.vocab_size}, {flat_dim}]"] + weight_reshape_inputs = [ + f"lm_head.MatMul.weight_Q{bits}G{self.int4_block_size}", + f"/model/constants/INT64/[{self.vocab_size}, {flat_dim}]", + ] weight_reshape_output = f"{weight_reshape_name}/output_0" # quantized weight dtype is uint8, see here - # https://github.com/microsoft/onnxruntime/blob/0c9356cb986fd4cd2c5d510909d31186010ba226/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py#L73 - self.make_reshape(weight_reshape_name, weight_reshape_inputs, dtype=ir.DataType.UINT8, shape=[self.vocab_size, flat_dim]) - self.make_node('GatherBlockQuantized', inputs=[weight_reshape_output, 'input_ids', 'lm_head.MatMul.weight_scale', 'lm_head.MatMul.weight_zp'], outputs=[gather_output], name=gather_name, domain="com.microsoft", bits=bits, block_size=int(self.int4_block_size), gather_axis=0, quantize_axis=1) + # https://github.com/microsoft/onnxruntime/blob/0c9356cb986fd4cd2c5d510909d31186010ba226/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py#L73 + self.make_reshape( + weight_reshape_name, weight_reshape_inputs, dtype=ir.DataType.UINT8, shape=[self.vocab_size, flat_dim] + ) + self.make_node( + "GatherBlockQuantized", + inputs=[weight_reshape_output, 
"input_ids", "lm_head.MatMul.weight_scale", "lm_head.MatMul.weight_zp"], + outputs=[gather_output], + name=gather_name, + domain="com.microsoft", + bits=bits, + block_size=int(self.int4_block_size), + gather_axis=0, + quantize_axis=1, + ) # Use Transpose + Gather for tied embeddings for float embedding layers elif self.shared_embeddings and self.unquantized_lm_head: transpose_name = f"{basename}/Transpose" transpose_output = f"{transpose_name}/output_0" - self.make_transpose(transpose_name, "lm_head.MatMul.weight", self.io_dtype, shape=[self.vocab_size, self.hidden_size], perm=[1, 0]) + self.make_transpose( + transpose_name, + "lm_head.MatMul.weight", + self.io_dtype, + shape=[self.vocab_size, self.hidden_size], + perm=[1, 0], + ) gather_name = f"{basename}/Gather" gather_output = f"{gather_name}/output_0" - self.make_node('Gather', inputs=[transpose_output, 'input_ids'], outputs=[gather_output], name=gather_name) + self.make_node("Gather", inputs=[transpose_output, "input_ids"], outputs=[gather_output], name=gather_name) else: weight = "model.embed_tokens.weight" self.make_initializer(embedding, weight, to=self.io_dtype) From 81e2f710584a651061f1f7a934f13aacebf7e0d8 Mon Sep 17 00:00:00 2001 From: jixiongdeng Date: Tue, 25 Nov 2025 18:20:20 +0000 Subject: [PATCH 7/7] Updated document --- src/python/py/models/README.md | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/python/py/models/README.md b/src/python/py/models/README.md index 5c2066d02..60d5ecf3e 100644 --- a/src/python/py/models/README.md +++ b/src/python/py/models/README.md @@ -19,6 +19,7 @@ This folder contains the model builder for quickly creating optimized and quanti - [Exclude Embedding Layer](#exclude-embedding-layer) - [Exclude Language Modeling Head](#exclude-language-modeling-head) - [Include Last Hidden States Output](#include-last-hidden-states-output) + - [Enable Shared Embeddings](#enable-shared-embeddings) - [Enable CUDA Graph](#enable-cuda-graph) - [Use 8 Bits Quantization in QMoE](#use-8-bits-quantization-in-qmoe) - [Use QDQ Pattern for Quantization](#use-qdq-pattern-for-quantization) @@ -214,18 +215,18 @@ Note that this is the same as outputting embeddings since the last hidden states #### Enable Shared Embeddings -This scenario is for when you want to enable weight sharing between the embedding layer and the language modeling head. This reduces model size and can improve memory efficiency, especially useful for models with tied embeddings (where `tie_word_embeddings=true` in config.json). Shared embeddings are automatically enabled if `tie_word_embeddings=true` in the model's config.json (can be overridden with `shared_embeddings=false`), but cannot be used with `exclude_embeds=true` or `exclude_lm_head=true`. In `-p int4` case, works with RTN and K-quant quantization algorithms. +This scenario is for when you want to enable weight sharing between the embedding layer and the language modeling head. This reduces model size and can improve memory efficiency, especially useful for models with tied embeddings (where `tie_word_embeddings=true` in config.json). Shared embeddings are automatically enabled if `tie_word_embeddings=true` in the model's config.json (can be overridden with `shared_embeddings=false`), but cannot be used with `exclude_embeds=true` or `exclude_lm_head=true`. 
-##### option1: int4 +##### Option 1: INT4 (for RTN and K-Quant) ``` # From wheel: -python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true +python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true int4_algo_config=k_quant # From source: -python3 builder.py -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true +python3 builder.py -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true int4_algo_config=k_quant ``` -##### option2: int4 + int8 embeddings +##### Option 2: INT4 + INT8 embeddings (for RTN Last and K-Quant Last) ``` # From wheel: python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true int4_algo_config=k_quant_last # From source: python3 builder.py -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true int4_algo_config=k_quant_last ``` -##### option3: int4 + fp16 embeddings +##### Option 3: INT4 + FP16 embeddings ``` # From wheel: -python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true int4_nodes_to_exclude=["/lm_head/MatMul"] +python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true int4_algo_config=rtn int4_nodes_to_exclude=/lm_head/MatMul # From source: -python3 builder.py -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true int4_nodes_to_exclude=["/lm_head/MatMul"] +python3 builder.py -m model_name -o path_to_output_folder -p int4 -e cuda --extra_options shared_embeddings=true int4_algo_config=rtn int4_nodes_to_exclude=/lm_head/MatMul ``` -##### option4: fp16(unquantized) +##### Option 4: FP16 (unquantized) ``` # From wheel: python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p fp16 -e cuda --extra_options shared_embeddings=true