71 changes: 53 additions & 18 deletions src/python/py/models/builder.py
@@ -283,12 +283,21 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
self.quant_attrs["config"] = config.quantization_config
self.quant_attrs["use_g_idx"] = config.quantization_config["desc_act"] if "desc_act" in config.quantization_config else False

self.int4_tied_embeddings = config.tie_word_embeddings if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None else False
self.int4_tied_embeddings = extra_options.get("int4_tied_embeddings", self.int4_tied_embeddings)
self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last"}
if not self.int8_lm_head:
# matmul_nbits_quantizer.py has a different naming for default quantization, so lm_head.MatMul.weight_Q{}G{} does not match.
self.int4_tied_embeddings = False

lm_head_excluded = "/lm_head/MatMul" in self.quant_attrs["int4"]["nodes_to_exclude"]

self.int4_tied_embeddings = extra_options.get("int4_tied_embeddings", config.tie_word_embeddings if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None else False)
self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last", "rtn_last"}
@tianleiwu (Contributor) commented on Nov 17, 2025:

Why is int4_algo_config used to set int8_lm_head here? Our models support different bits for different weights. I think a more straightforward setting would be better, such as mapping weight names to n_bits, or a dedicated configuration for lm_head.
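
For reference, a minimal sketch (not part of this PR) of the "weight name to n_bits" style of setting the reviewer mentions; it mirrors the customized_weight_config dict that make_int4_algo_config builds later in this diff. The layer index shown is illustrative.

# Hypothetical per-weight bits mapping; only /lm_head/MatMul appears in the PR itself.
customized_weight_config = {
    "/lm_head/MatMul": {"bits": 8},
    "/model/layers.0/mlp/down_proj/MatMul": {"bits": 8},
}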

# matmul_nbits_quantizer.py uses a different naming scheme for default quantization, so lm_head.MatMul.weight_Q{}G{} does not match.
# Tied embeddings via lm_head.MatMul.weight_Q{}G{} only work with rtn/k_quant at 4 bits, or with an int8 lm_head.
self.int4_tied_embeddings = (
self.int4_tied_embeddings
and not lm_head_excluded
and (self.int8_lm_head or extra_options.get("int4_algo_config", "default") in {"rtn", "k_quant"})
)

# Check if shared embeddings are used for float embeddings and lm_head
self.shared_embeddings = extra_options.get("shared_embeddings", False)

def to_str_dtype(self, dtype: ir.DataType) -> str:
return dtype.name
@@ -375,6 +384,10 @@ def make_attention_init(self):
and not self.attention_attrs["k_norm"]
)

# Allow extra_options to override use_packed_matmul
if "unpack_matmul" in self.extra_options:
self.attention_attrs["use_packed_matmul"] = not self.extra_options.get("unpack_matmul", False)

# Some EPs don't support fusing rotary embeddings inside GQA yet
self.attention_attrs["use_rope_in_attn"] = self.ep not in ["dml"]
if self.attention_attrs["use_rope_in_attn"]:
@@ -484,11 +497,15 @@ def make_int4_algo_config(self, quant_method: str):
customized_weight_config = {}
int4_algo_config = None

if quant_method == "rtn":
int4_algo_config = RTNWeightOnlyQuantConfig()
if quant_method in {"rtn", "rtn_last"}:
Contributor commented:

I think this can be simplified to the following.

if quant_method in {"rtn", "rtn_last"}:
    if quant_method == "rtn_last":
        customized_weight_config["/lm_head/MatMul"] = {"bits": 8}
    int4_algo_config = RTNWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)

Contributor Author replied: Done.

if quant_method == "rtn_last":
customized_weight_config["/lm_head/MatMul"] = {"bits": 8}
int4_algo_config = RTNWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)

elif quant_method in {"k_quant_mixed", "k_quant_last"}:
elif quant_method in {"k_quant", "k_quant_mixed", "k_quant_last"}:
Contributor commented:

I think this can be simplified to the following.

elif quant_method in {"k_quant", "k_quant_mixed", "k_quant_last"}:
    if quant_method != "k_quant":
        customized_weight_config["/lm_head/MatMul"] = {"bits": 8}

    if quant_method == "k_quant_mixed":
        # k_quant_mixed is from llama.cpp.
        # Reference: https://github.com/ggml-org/llama.cpp/blob/36667c8edcded08063ed51c7d57e9e086bbfc903/src/llama-quant.cpp#L136
        # We also consider that some MatMuls are more sensitive to quantization than others.
        layers_to_exclude = [
            i
            for i in range(self.num_layers)
            if i < self.num_layers / 8 or i >= 7 * self.num_layers / 8 or (i - (round)(self.num_layers / 8)) % 3 == 2
        ]
        for i in layers_to_exclude:
            customized_weight_config["/model/layers." + str(i) + "/attn/qkv_proj/MatMul"] = {"bits": 8}
            customized_weight_config["/model/layers." + str(i) + "/attn/v_proj/MatMul"] = {"bits": 8}
            customized_weight_config["/model/layers." + str(i) + "/mlp/down_proj/MatMul"] = {"bits": 8}

    int4_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)

Contributor Author replied: Done.

from onnxruntime.quantization.matmul_nbits_quantizer import KQuantWeightOnlyQuantConfig
if quant_method != "k_quant":
customized_weight_config["/lm_head/MatMul"] = {"bits": 8}

if quant_method == "k_quant_mixed":
# k_quant_mixed is from llama.cpp.
@@ -504,7 +521,6 @@ def make_int4_algo_config(self, quant_method: str):
customized_weight_config["/model/layers." + str(i) + "/attn/v_proj/MatMul"] = {"bits": 8}
customized_weight_config["/model/layers." + str(i) + "/mlp/down_proj/MatMul"] = {"bits": 8}

customized_weight_config["/lm_head/MatMul"] = {"bits": 8}
int4_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)

return int4_algo_config
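
As a worked example (not in the PR), the k_quant_mixed selection rule above keeps the first and last eighth of the layers, plus every third layer in between, at 8 bits. For a hypothetical 32-layer model:

# Illustrative only: reproduce the layer selection for num_layers = 32.
num_layers = 32
layers_to_exclude = [
    i
    for i in range(num_layers)
    if i < num_layers / 8 or i >= 7 * num_layers / 8 or (i - round(num_layers / 8)) % 3 == 2
]
print(layers_to_exclude)
# [0, 1, 2, 3, 6, 9, 12, 15, 18, 21, 24, 27, 28, 29, 30, 31]
# These 16 layers keep their qkv_proj/v_proj/down_proj MatMuls at 8 bits; the rest stay int4.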
@@ -1081,20 +1097,30 @@ def make_packed_add(self, q_add, k_add, v_add, name, root_input, **kwargs):

def make_embedding(self, embedding):
basename = "/model/embed_tokens"
if self.int4_tied_embeddings:
# Use GatherBlockQuantized only when tied embeddings are enabled and the exported model is quantized (i.e. the onnx_dtype set in set_onnx_dtype is INT4/UINT4).
if self.int4_tied_embeddings and self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4}:
gather_name = f"{basename}/GatherBlockQuantized"
gather_output = f"{gather_name}/output_0"

weight_reshape_name = f"{basename}/Reshape"
bits = 8 if self.int8_lm_head else 4
weight_reshape_inputs = [f"lm_head.MatMul.weight_Q{bits}G{self.int4_block_size}", f"/model/constants/INT64/[{self.vocab_size}, {self.hidden_size}]"]
flat_dim = self.hidden_size * bits // 8
weight_reshape_inputs = [f"lm_head.MatMul.weight_Q{bits}G{self.int4_block_size}", f"/model/constants/INT64/[{self.vocab_size}, {flat_dim}]"]
weight_reshape_output = f"{weight_reshape_name}/output_0"
# quantized weight dtype is uint8, see here
# https://github.com/microsoft/onnxruntime/blob/0c9356cb986fd4cd2c5d510909d31186010ba226/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py#L73
self.make_reshape(weight_reshape_name, weight_reshape_inputs, dtype=ir.DataType.UINT8, shape=['vocab_size', 'hidden_size'])

self.make_node('GatherBlockQuantized', inputs=[weight_reshape_output, 'input_ids', 'lm_head.MatMul.weight_scale', 'lm_head.MatMul.weight_zp'], outputs=[gather_output], name=gather_name, domain="com.microsoft", bits=bits, block_size=int(self.int4_block_size))
self.make_reshape(weight_reshape_name, weight_reshape_inputs, dtype=ir.DataType.UINT8, shape=[self.vocab_size, flat_dim])
self.make_node('GatherBlockQuantized', inputs=[weight_reshape_output, 'input_ids', 'lm_head.MatMul.weight_scale', 'lm_head.MatMul.weight_zp'], outputs=[gather_output], name=gather_name, domain="com.microsoft", bits=bits, block_size=int(self.int4_block_size), gather_axis=0, quantize_axis=1)
elif self.shared_embeddings:
transpose_name = f"{basename}/Transpose"
transpose_output = f"{transpose_name}/output_0"
self.make_transpose(transpose_name, "lm_head.MatMul.weight", self.io_dtype, [self.vocab_size, self.hidden_size], [1, 0])

gather_name = f"{basename}/Gather"
gather_output = f"{gather_name}/output_0"
self.make_node('Gather', inputs=[transpose_output, 'input_ids'], outputs=[gather_output], name=gather_name)
else:
# Default behavior: separate embed_tokens and lm_head weights
weight = "model.embed_tokens.weight"
self.make_initializer(embedding, weight, to=self.io_dtype)
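
A small illustration (not in the PR) of the packing arithmetic used above: the quantized lm_head weight is stored as uint8 with the 4- or 8-bit values packed into bytes, so each row flattens to hidden_size * bits // 8 bytes before the Reshape/GatherBlockQuantized pair. The dimensions below are hypothetical.

# Hypothetical dimensions, e.g. a Llama-style model.
vocab_size, hidden_size = 32000, 4096
for bits in (4, 8):
    flat_dim = hidden_size * bits // 8
    print(bits, flat_dim)  # 4 -> 2048, 8 -> 4096
# The reshaped [vocab_size, flat_dim] uint8 tensor is then indexed by input_ids
# along gather_axis=0, with quantize_axis=1 matching the per-row block scales.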

@@ -4660,11 +4686,20 @@ def get_args():
Use this option when you want to exclude certain nodes from being quantized.
Separate the node names with a ',' when passing them here (e.g. int4_nodes_to_exclude=/lm_head/MatMul,/model/embed_tokens/Gather)
int4_algo_config = Method for int4 quantization. Default is 'default'.
Currently supported options are: 'default', 'rtn', 'k_quant_mixed', 'k_quant_last'.
Currently supported options are: 'default', 'rtn', 'rtn_last', 'k_quant', 'k_quant_mixed', 'k_quant_last'.
rtn = RTN algorithm for int4 quantization.
rtn_last = RTN algorithm where only the last MatMul (/lm_head/MatMul) is quantized as int8. Other MatMuls are quantized as int4.
k_quant = k_quant algorithm for int4 quantization.
k_quant_mixed = k_quant algorithm with mixed precision (int4 + int8).
k_quant_last = k_quant algorithm where only the last MatMul (/lm_head/MatMul) is quantized as int8. Other MatMuls are quantized as int4.
int4_tied_embeddings = Enable weight sharing for quantization. Default is false.
Use this option when you want to share the weights in the embedding and unembedding.
int4_tied_embeddings = Enable weight sharing for quantized models (INT4/UINT4/INT8/UINT8). Default is false.
Contributor commented:

Do we need this option? If we share embeddings (lm_head), that also means the quantized weights should be shared. As long as we know the quantization method and the number of bits, that should be enough.

Use this option when you want to share the quantized weights between embedding and LM head layers.
Only works with rtn and k_quant quantization algorithms.
Cannot be used if LM head is excluded from quantization (use shared_embeddings instead).
shared_embeddings = Enable weight sharing for FP16/FP32/BF16 weights. Default is false.
Use this option when you want to share the float weights between embedding and LM head layers.
Works for pure FP models or INT4 models where LM head is excluded from quantization.
This reduces model size by eliminating duplicate weights.
num_hidden_layers = Manually specify the number of layers in your ONNX model.
Used for unit testing purposes.
filename = Filename for ONNX model (default is 'model.onnx').
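
To tie the help text together, a hypothetical example (not from the docs) of the extra_options dict that Model.__init__ receives for these settings; on the command line they are passed as key=value pairs.

# Illustrative values only.
extra_options = {
    "int4_algo_config": "k_quant_mixed",   # int8 LM head + mixed int4/int8 layers
    "int4_tied_embeddings": True,          # share the quantized lm_head weights with embed_tokens
}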