Shared emb_tokens/lm_head on fp16 & uint4 weights #1854
Changes from all commits: 4538bb3, 6207e1f, 958574d, b75bb1e, 76113df, 37cbec1, 033923b, 27ecca4, 3046400
```diff
@@ -283,12 +283,21 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
             self.quant_attrs["config"] = config.quantization_config
             self.quant_attrs["use_g_idx"] = config.quantization_config["desc_act"] if "desc_act" in config.quantization_config else False

-        self.int4_tied_embeddings = config.tie_word_embeddings if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None else False
-        self.int4_tied_embeddings = extra_options.get("int4_tied_embeddings", self.int4_tied_embeddings)
-        self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last"}
-        if not self.int8_lm_head:
-            # matmul_nbits_quantizer.py has a different naming for default quantization, so lm_head.MatMul.weight_Q{}G{} does not match.
-            self.int4_tied_embeddings = False
+        lm_head_excluded = "/lm_head/MatMul" in self.quant_attrs["int4"]["nodes_to_exclude"]
+
+        self.int4_tied_embeddings = extra_options.get("int4_tied_embeddings", config.tie_word_embeddings if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None else False)
+        self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last", "rtn_last"}
+        # matmul_nbits_quantizer.py has a different naming for default quantization, so lm_head.MatMul.weight_Q{}G{} does not match.
+        # tied_embeddings lm_head.MatMul.weight_Q{}G{} only works with rtn&k_quant on 4bit, or with int8 lm_head
+        self.int4_tied_embeddings = (
+            self.int4_tied_embeddings
+            and not lm_head_excluded
+            and (self.int8_lm_head or extra_options.get("int4_algo_config", "default") in {"rtn", "k_quant"})
+        )
+
+        # Check if shared embeddings are used for float embeddings and lm_head
+        self.shared_embeddings = extra_options.get("shared_embeddings", False)

     def to_str_dtype(self, dtype: ir.DataType) -> str:
         return dtype.name
```
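For reference, the gating above can be read as a small standalone rule: the quantized lm_head weight is reused as the embedding table only when the lm_head is actually quantized and the algorithm produces the expected `lm_head.MatMul.weight_Q{bits}G{block}` initializer name. The sketch below mirrors that condition; the function name and arguments are illustrative, not part of the PR.

```python
# A minimal standalone sketch (not the builder itself) of the new gating logic
# that decides whether the quantized lm_head weight can double as the embedding.
def resolve_tied_embeddings(extra_options, tie_word_embeddings, lm_head_excluded):
    algo = extra_options.get("int4_algo_config", "default")
    int8_lm_head = algo in {"k_quant_mixed", "k_quant_last", "rtn_last"}
    tied = extra_options.get("int4_tied_embeddings", tie_word_embeddings)
    # Sharing only works when the lm_head is not excluded from quantization and
    # the algorithm emits the lm_head.MatMul.weight_Q{bits}G{block} initializer.
    return tied and not lm_head_excluded and (int8_lm_head or algo in {"rtn", "k_quant"})

# Examples:
# resolve_tied_embeddings({"int4_algo_config": "k_quant_mixed"}, True, False) -> True
# resolve_tied_embeddings({"int4_algo_config": "default"}, True, False)       -> False (name mismatch)
# resolve_tied_embeddings({"int4_algo_config": "rtn"}, True, True)            -> False (lm_head excluded)
```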
```diff
@@ -375,6 +384,10 @@ def make_attention_init(self):
             and not self.attention_attrs["k_norm"]
         )

+        # Allow extra_options to override use_packed_matmul
+        if "unpack_matmul" in self.extra_options:
+            self.attention_attrs["use_packed_matmul"] = not self.extra_options.get("unpack_matmul", False)
+
         # Some EPs don't support fusing rotary embeddings inside GQA yet
         self.attention_attrs["use_rope_in_attn"] = self.ep not in ["dml"]
         if self.attention_attrs["use_rope_in_attn"]:
```
```diff
@@ -484,11 +497,15 @@ def make_int4_algo_config(self, quant_method: str):
         customized_weight_config = {}
         int4_algo_config = None

-        if quant_method == "rtn":
-            int4_algo_config = RTNWeightOnlyQuantConfig()
+        if quant_method in {"rtn", "rtn_last"}:
```
Contributor:
I think this can be simplified to the following.

```python
if quant_method in {"rtn", "rtn_last"}:
    if quant_method == "rtn_last":
        customized_weight_config["/lm_head/MatMul"] = {"bits": 8}
    int4_algo_config = RTNWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
```

Author:
Done.
```diff
+            if quant_method == "rtn_last":
+                customized_weight_config["/lm_head/MatMul"] = {"bits": 8}
+            int4_algo_config = RTNWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)

-        elif quant_method in {"k_quant_mixed", "k_quant_last"}:
+        elif quant_method in {"k_quant", "k_quant_mixed", "k_quant_last"}:
```
Contributor:
I think this can be simplified to the following.

```python
elif quant_method in {"k_quant", "k_quant_mixed", "k_quant_last"}:
    if quant_method != "k_quant":
        customized_weight_config["/lm_head/MatMul"] = {"bits": 8}
    if quant_method == "k_quant_mixed":
        # k_quant_mixed is from llama.cpp.
        # Reference: https://github.com/ggml-org/llama.cpp/blob/36667c8edcded08063ed51c7d57e9e086bbfc903/src/llama-quant.cpp#L136
        # We also consider some MatMuls are more senstive to quantization than other MatMuls.
        layers_to_exclude = [
            i
            for i in range(self.num_layers)
            if i < self.num_layers / 8 or i >= 7 * self.num_layers / 8 or (i - (round)(self.num_layers / 8)) % 3 == 2
        ]
        for i in layers_to_exclude:
            customized_weight_config["/model/layers." + str(i) + "/attn/qkv_proj/MatMul"] = {"bits": 8}
            customized_weight_config["/model/layers." + str(i) + "/attn/v_proj/MatMul"] = {"bits": 8}
            customized_weight_config["/model/layers." + str(i) + "/mlp/down_proj/MatMul"] = {"bits": 8}
    int4_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
```

Author:
Done.
```diff
             from onnxruntime.quantization.matmul_nbits_quantizer import KQuantWeightOnlyQuantConfig
+            if quant_method != "k_quant":
+                customized_weight_config["/lm_head/MatMul"] = {"bits": 8}

             if quant_method == "k_quant_mixed":
                 # k_quant_mixed is from llama.cpp.
```
```diff
@@ -504,7 +521,6 @@ def make_int4_algo_config(self, quant_method: str):
                     customized_weight_config["/model/layers." + str(i) + "/attn/v_proj/MatMul"] = {"bits": 8}
                     customized_weight_config["/model/layers." + str(i) + "/mlp/down_proj/MatMul"] = {"bits": 8}

-            customized_weight_config["/lm_head/MatMul"] = {"bits": 8}
             int4_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)

         return int4_algo_config
```
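As a worked illustration of the llama.cpp-style rule quoted in the review comment above, the snippet below shows which layer indices k_quant_mixed keeps at 8 bits. The 32-layer count is an assumption for the example only, not a value from the PR.

```python
# Layers kept at 8 bits by the k_quant_mixed selection rule in
# make_int4_algo_config; 32 layers is a hypothetical example value.
num_layers = 32
layers_to_exclude = [
    i
    for i in range(num_layers)
    if i < num_layers / 8 or i >= 7 * num_layers / 8 or (i - round(num_layers / 8)) % 3 == 2
]
print(layers_to_exclude)
# [0, 1, 2, 3, 6, 9, 12, 15, 18, 21, 24, 27, 28, 29, 30, 31]
```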
```diff
@@ -1081,20 +1097,30 @@ def make_packed_add(self, q_add, k_add, v_add, name, root_input, **kwargs):

     def make_embedding(self, embedding):
         basename = "/model/embed_tokens"
-        if self.int4_tied_embeddings:
+        # Use GatherBlockQuantized if and only if tied embeddings are enabled and export model is quantized. quantized d_type in set_onnx_dtype is INT4/UINT4
+        if self.int4_tied_embeddings and self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4}:
             gather_name = f"{basename}/GatherBlockQuantized"
             gather_output = f"{gather_name}/output_0"

             weight_reshape_name = f"{basename}/Reshape"
             bits = 8 if self.int8_lm_head else 4
-            weight_reshape_inputs = [f"lm_head.MatMul.weight_Q{bits}G{self.int4_block_size}", f"/model/constants/INT64/[{self.vocab_size}, {self.hidden_size}]"]
+            flat_dim = self.hidden_size * bits // 8
+            weight_reshape_inputs = [f"lm_head.MatMul.weight_Q{bits}G{self.int4_block_size}", f"/model/constants/INT64/[{self.vocab_size}, {flat_dim}]"]
             weight_reshape_output = f"{weight_reshape_name}/output_0"
-            self.make_reshape(weight_reshape_name, weight_reshape_inputs, dtype=ir.DataType.UINT8, shape=['vocab_size', 'hidden_size'])
-
-            self.make_node('GatherBlockQuantized', inputs=[weight_reshape_output, 'input_ids', 'lm_head.MatMul.weight_scale', 'lm_head.MatMul.weight_zp'], outputs=[gather_output], name=gather_name, domain="com.microsoft", bits=bits, block_size=int(self.int4_block_size))
+            # quantized weight dtype is uint8, see here
+            # https://github.com/microsoft/onnxruntime/blob/0c9356cb986fd4cd2c5d510909d31186010ba226/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py#L73
+            self.make_reshape(weight_reshape_name, weight_reshape_inputs, dtype=ir.DataType.UINT8, shape=[self.vocab_size, flat_dim])
+            self.make_node('GatherBlockQuantized', inputs=[weight_reshape_output, 'input_ids', 'lm_head.MatMul.weight_scale', 'lm_head.MatMul.weight_zp'], outputs=[gather_output], name=gather_name, domain="com.microsoft", bits=bits, block_size=int(self.int4_block_size), gather_axis=0, quantize_axis=1)
+        elif self.shared_embeddings:
+            transpose_name = f"{basename}/Transpose"
+            transpose_output = f"{transpose_name}/output_0"
+            self.make_transpose(transpose_name, "lm_head.MatMul.weight", self.io_dtype, [self.vocab_size, self.hidden_size], [1, 0])
+
+            gather_name = f"{basename}/Gather"
+            gather_output = f"{gather_name}/output_0"
+            self.make_node('Gather', inputs=[transpose_output, 'input_ids'], outputs=[gather_output], name=gather_name)
         else:
             #Default behavior: Separate emb_tokens and lmhead weights
             weight = "model.embed_tokens.weight"
             self.make_initializer(embedding, weight, to=self.io_dtype)
```
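A quick sanity check of the reshape target fed into GatherBlockQuantized: the packed uint8 weight stores hidden_size * bits / 8 bytes per vocabulary row, so the flattened dimension depends on whether the lm_head was quantized to 4 or 8 bits. The vocab/hidden sizes below are illustrative Llama-like values, not taken from the PR.

```python
# Shape of the packed lm_head weight consumed by GatherBlockQuantized,
# using illustrative dimensions (not values from the PR).
vocab_size, hidden_size = 32000, 4096

for bits in (4, 8):
    flat_dim = hidden_size * bits // 8          # bytes per row of packed uint8 data
    total_mib = vocab_size * flat_dim / 2**20
    print(f"{bits}-bit: reshape to [{vocab_size}, {flat_dim}] -> {total_mib:.1f} MiB packed")
# 4-bit: reshape to [32000, 2048] -> 62.5 MiB packed
# 8-bit: reshape to [32000, 4096] -> 125.0 MiB packed
```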
```diff
@@ -4660,11 +4686,20 @@ def get_args():
                 Use this option when you want to exclude certain nodes from being quantized.
                 Separate the node names with a ',' when passing them here (e.g. int4_nodes_to_exclude=/lm_head/MatMul,/model/embed_tokens/Gather)
             int4_algo_config = Method for int4 quantization. Default is 'default'.
-                Currently supported options are: 'default', 'rtn', 'k_quant_mixed', 'k_quant_last'.
+                Currently supported options are: 'default', 'rtn', 'rtn_last', 'k_quant', 'k_quant_mixed', 'k_quant_last'.
                 rtn = RTN algorithm for int4 quantization.
+                rtn_last = RTN algorithm where only the last MatMul (/lm_head/MatMul) is quantized as int8. Other MatMuls are quantized as int4.
+                k_quant = k_quant algorithm for int4 quantization.
                 k_quant_mixed = k_quant algorithm with mixed precision (int4 + int8).
                 k_quant_last = k_quant algorithm where only the last MatMul (/lm_head/MatMul) is quantized as int8. Other MatMuls are quantized as int4.
-            int4_tied_embeddings = Enable weight sharing for quantization. Default is false.
-                Use this option when you want to share the weights in the embedding and unembedding.
+            int4_tied_embeddings = Enable weight sharing for quantized models (INT4/UINT4/INT8/UINT8). Default is false.
```
Contributor:
Do we need this option? If we share the embeddings (lm_head), that also means the quantized weights shall be shared. As long as we know the quantization method and the number of bits, that will be enough.
| Use this option when you want to share the quantized weights between embedding and LM head layers. | ||
| Only works with rtn and k_quant quantization algorithms. | ||
| Cannot be used if LM head is excluded from quantization (use shared_embeddings instead). | ||
| shared_embeddings = Enable weight sharing for FP16/FP32/BF16 weights. Default is false. | ||
| Use this option when you want to share the float weights between embedding and LM head layers. | ||
| Works for pure FP models or INT4 models where LM head is excluded from quantization. | ||
| This reduces model size by eliminating duplicate weights. | ||
| num_hidden_layers = Manually specify the number of layers in your ONNX model. | ||
| Used for unit testing purposes. | ||
| filename = Filename for ONNX model (default is 'model.onnx'). | ||
|
|
||
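To make the option list easier to scan, here is a compact summary, derived only from the help text above (it is not additional code in the PR), of how each int4_algo_config value treats the lm_head MatMul:

```python
# Bit width applied to /lm_head/MatMul per int4_algo_config value, as described
# in the help text above ('default' follows the quantizer's own defaults).
LM_HEAD_BITS = {
    "rtn": 4,
    "rtn_last": 8,
    "k_quant": 4,
    "k_quant_mixed": 8,   # selected attention/MLP MatMuls are also kept at 8 bits
    "k_quant_last": 8,
}
```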
Review comment:
Why is int4_algo_config used to set int8_lm_head here? Our model supports different bits for different weights. I think it would be better to have a more direct setting, such as a mapping from weight name to n_bits, or a dedicated configuration for the lm_head.
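One possible shape for that suggestion, purely as a hypothetical sketch (none of these names exist in the PR): an explicit per-weight bit-width map that the builder could consume instead of inferring int8_lm_head from the algorithm name.

```python
# Hypothetical sketch of the reviewer's idea: map weight names directly to bit
# widths rather than deriving int8_lm_head from int4_algo_config.
DEFAULT_BITS = 4

weight_bits = {
    "/lm_head/MatMul": 8,                        # keep the LM head at 8 bits
    "/model/layers.0/mlp/down_proj/MatMul": 8,   # example of a sensitive MatMul
}

def bits_for(weight_name: str) -> int:
    # Fall back to the global default when a weight has no explicit entry.
    return weight_bits.get(weight_name, DEFAULT_BITS)

print(bits_for("/lm_head/MatMul"))                        # 8
print(bits_for("/model/layers.5/attn/qkv_proj/MatMul"))   # 4
```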