Added Support for Vanilla and Quantized ChatGLM3 Models to Model Builder #921
base: main
Changes from all commits
@@ -3,6 +3,7 @@
 # Licensed under the MIT License. See License.txt in the project root for
 # license information.
 # --------------------------------------------------------------------------
+# Modifications Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved
 """
 Run this script to create the desired ONNX model.
 """
@@ -21,16 +22,16 @@
 class Model:
     def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
-        self.context_length = config.max_position_embeddings
-        self.original_context_length = config.original_max_position_embeddings if hasattr(config, "original_max_position_embeddings") else config.rope_scaling["original_max_position_embeddings"] if hasattr(config, "rope_scaling") and hasattr(config.rope_scaling, "original_max_position_embeddings") else config.max_position_embeddings
+        self.context_length = config.seq_length if hasattr(config, "seq_length") else config.max_position_embeddings
+        self.original_context_length = config.original_max_position_embeddings if hasattr(config, "original_max_position_embeddings") else config.rope_scaling["original_max_position_embeddings"] if hasattr(config, "rope_scaling") and hasattr(config.rope_scaling, "original_max_position_embeddings") else self.context_length
         self.window_size = config.sliding_window if hasattr(config, "sliding_window") else -1  # default is -1 in GroupQueryAttention kernel
-        self.intermediate_size = config.intermediate_size
+        self.intermediate_size = config.ffn_hidden_size if hasattr(config, "ffn_hidden_size") else config.intermediate_size
         self.hidden_size = config.hidden_size
-        self.num_kv_heads = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads
+        self.num_kv_heads = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.multi_query_group_num if hasattr(config, "multi_query_group_num") else config.num_attention_heads
         self.num_attn_heads = config.num_attention_heads
         self.head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
-        self.num_layers = int(extra_options["num_hidden_layers"]) if "num_hidden_layers" in extra_options else config.num_hidden_layers
+        self.num_layers = int(extra_options["num_hidden_layers"]) if "num_hidden_layers" in extra_options else config.num_hidden_layers if hasattr(config, "num_hidden_layers") else config.num_layers
         self.vocab_size = config.vocab_size
         self.activation = config.hidden_activation if hasattr(config, "hidden_activation") and config.hidden_activation is not None else config.hidden_act
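For context, here is a minimal sketch of how these new fallback chains resolve for a ChatGLM-style config. The config object is mocked with SimpleNamespace; the values mirror those commonly published for chatglm3-6b and are illustrative only.

```python
from types import SimpleNamespace

# Mocked ChatGLM3-style config (illustrative values, not loaded from config.json).
cfg = SimpleNamespace(
    seq_length=8192,
    ffn_hidden_size=13696,
    hidden_size=4096,
    multi_query_group_num=2,
    num_attention_heads=32,
    num_layers=28,
)

# The same fallback chains as in the diff above.
context_length = cfg.seq_length if hasattr(cfg, "seq_length") else cfg.max_position_embeddings
intermediate_size = cfg.ffn_hidden_size if hasattr(cfg, "ffn_hidden_size") else cfg.intermediate_size
num_kv_heads = (
    cfg.num_key_value_heads if hasattr(cfg, "num_key_value_heads")
    else cfg.multi_query_group_num if hasattr(cfg, "multi_query_group_num")
    else cfg.num_attention_heads
)
num_layers = cfg.num_hidden_layers if hasattr(cfg, "num_hidden_layers") else cfg.num_layers

print(context_length, intermediate_size, num_kv_heads, num_layers)  # 8192 13696 2 28
```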
@@ -1432,14 +1433,21 @@
             raise NotImplementedError(f"The MLP layer type is not set.")

     def make_mlp_unpacked(self, layer_id, mlp, root_input):
+        packed_proj = getattr(mlp, "gate_up_proj", None) or getattr(
+            mlp, "dense_h_to_4h", None
+        )
         mlp.gate_proj = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size)
-        mlp.gate_proj.weight = torch.nn.Parameter(mlp.gate_up_proj.weight[: self.intermediate_size, :])
+        mlp.gate_proj.weight = torch.nn.Parameter(
+            packed_proj.weight[: self.intermediate_size, :]
+        )

         mlp.up_proj = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size)
-        mlp.up_proj.weight = torch.nn.Parameter(mlp.gate_up_proj.weight[self.intermediate_size :, :])
+        mlp.up_proj.weight = torch.nn.Parameter(
+            packed_proj.weight[self.intermediate_size :, :]
+        )

         # Delete original packed weights
-        del mlp.gate_up_proj
+        del packed_proj

     def make_mlp_proj(self, layer_id, mlp, root_input):
         # Make nodes for the MLP subgraph

Review comment: Can these definitions be made in one line?
@@ -1450,7 +1458,7 @@
         #              \    |
        #               \  ActFunc
        #                \ /
        #                Mul
        #                 |
        #          DownProjMatMul

CodeQL warning (code scanning): Unnecessary delete statement in function; the deletion of the local variable packed_proj in make_mlp_unpacked is unnecessary.
@@ -1469,8 +1477,11 @@
         self.make_mul(mul_name, mul_inputs, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size])

         # Make output MatMul node
+        down_proj = getattr(mlp, "down_proj", None) or getattr(
+            mlp, "dense_4h_to_h", None
+        )
         down_basename = f"/model/layers.{layer_id}/mlp/down_proj/MatMul"
-        down_name = self.make_matmul(mlp.down_proj, down_basename, f"{mul_name}/output_0")
+        down_name = self.make_matmul(down_proj, down_basename, f"{mul_name}/output_0")

         # Assign output 0 of previous MatMul as skip input to next SkipLayerNorm
         self.layernorm_attrs["skip_input"] = f"{down_name}/output_0"

Review comment: Can this definition be made in one line?
@@ -1664,7 +1675,7 @@
         return gelu_name

     def make_activation(self, layer_id, root_input):
-        if self.activation in {"silu", "swish"}:
+        if self.activation in {"silu", "swish", "swiglu"}:
             output_name = self.make_activation_with_mul(layer_id, root_input, activation="Sigmoid", domain=None)
         elif self.activation in {"gelu_new", "gelu_fast", "gelu_pytorch_tanh"}:
             output_name = self.make_gelu(layer_id, root_input, activation="FastGelu")
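Routing "swiglu" through the existing Sigmoid-plus-Mul path is sound because SiLU(x) = x * sigmoid(x), and the remaining gate/up multiplication of SwiGLU is already emitted by the surrounding MLP subgraph (the Mul node in the MLP diagram earlier). A quick numerical check of the identity in plain PyTorch:

```python
import torch

x = torch.randn(4, 8)
gate, up = torch.randn(4, 8), torch.randn(4, 8)

# SiLU is x * sigmoid(x): exactly what a Sigmoid node followed by a Mul computes.
assert torch.allclose(torch.nn.functional.silu(x), x * torch.sigmoid(x))

# SwiGLU(gate, up) = SiLU(gate) * up; the second Mul comes from the MLP subgraph.
swiglu = torch.nn.functional.silu(gate) * up
```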
@@ -1744,7 +1755,17 @@
             from onnxruntime_genai.models.quantized_model import QuantModel
             q_size = self.num_attn_heads * self.head_size
             kv_size = self.num_kv_heads * self.head_size
-            model = QuantModel.from_pretrained(self.quant_type, input_path, self.quant_attrs["bits"], self.quant_attrs["group_size"], self.quant_attrs["use_g_idx"], q_size, kv_size, self.intermediate_size, self.num_layers)
+            model = QuantModel.from_pretrained(
+                self.quant_type,
+                input_path,
+                self.quant_attrs["bits"],
+                self.quant_attrs["group_size"],
+                self.quant_attrs["use_g_idx"],
+                q_size,
+                kv_size,
+                self.intermediate_size,
+                self.num_layers,
+            )
         else:
             # Load PyTorch model
             extra_kwargs = {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {}
@@ -1753,6 +1774,7 @@
         # Loop through model and map each module to ONNX/ORT ops
         self.layer_id = 0
         for module in model.modules():
+
             if isinstance(module, torch.nn.Embedding) or (hasattr(model, "embedding") and module == model.embedding):
                 # Checks (Hugging Face logic) or (GGUF logic)
                 if not self.exclude_embeds:
@@ -1764,7 +1786,7 @@
                 self.layernorm_attrs["root_input"] = "inputs_embeds"
                 self.layernorm_attrs["skip_input"] = "inputs_embeds"

-            elif module.__class__.__name__.endswith("DecoderLayer") and self.layer_id < self.num_layers:
+            elif (module.__class__.__name__.endswith("DecoderLayer") or module.__class__.__name__.endswith("GLMBlock")) and self.layer_id < self.num_layers:
                 # Each decoder layer of model
                 print(f"Reading decoder layer {self.layer_id}")
                 self.make_layer(self.layer_id, module)
@@ -1774,7 +1796,7 @@
                 # SkipLayerNorm after last decoder layer (MatMul --> SkipLayerNorm)
                 print("Reading final norm")
                 self.make_layernorm(self.layer_id, module, skip=True, simple=self.layernorm_attrs["simple"], location="final_norm")
-
             elif (isinstance(module, torch.nn.Linear) and module.out_features == self.vocab_size) or (hasattr(model, "lm_head") and module == model.lm_head):
                 # Checks (Hugging Face logic) or (GGUF logic)
                 if not self.exclude_lm_head:

Review comment: The extra line space between the conditions helps make the different cases more readable. Can this line deletion be reverted?
@@ -1785,12 +1807,13 @@
         del model

     def has_final_norm(self, module, model):
         # Hugging Face names
         hf_norm = hasattr(model, "model") and hasattr(model.model, "norm") and module == model.model.norm
         hf_final_layernorm = hasattr(model, "model") and hasattr(model.model, "final_layernorm") and module == model.model.final_layernorm
+        hf_transformer_final_layernorm = hasattr(model, "transformer") and hasattr(model.transformer, "encoder") and hasattr(model.transformer.encoder, "final_layernorm") and module == model.transformer.encoder.final_layernorm
         # GGUF names
         gguf_final_norm = hasattr(model, "final_norm") and module == model.final_norm
-        return hf_norm or hf_final_layernorm or gguf_final_norm
+        return hf_norm or hf_final_layernorm or hf_transformer_final_layernorm or gguf_final_norm

     def make_preprocessing_nodes(self):
         self.make_attention_mask_reformatting()
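The new hf_transformer_final_layernorm branch covers ChatGLM's module nesting, where the final norm lives at model.transformer.encoder.final_layernorm instead of model.model.norm. A mocked structural check of that branch (no real modules involved):

```python
from types import SimpleNamespace

final_ln = object()  # stands in for the final LayerNorm module
# ChatGLM-style nesting: model -> transformer -> encoder -> final_layernorm
model = SimpleNamespace(
    transformer=SimpleNamespace(encoder=SimpleNamespace(final_layernorm=final_ln))
)

is_final = (
    hasattr(model, "transformer")
    and hasattr(model.transformer, "encoder")
    and hasattr(model.transformer.encoder, "final_layernorm")
    and final_ln == model.transformer.encoder.final_layernorm
)
assert is_final  # the ChatGLM layout is now recognized
```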
@@ -2613,6 +2636,36 @@
             self.layernorm_attrs["last_layernorm"] = True


+class ChatGLMModel(Model):
+    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
+        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
+        self.rotemb_attrs["num_heads"] = self.num_attn_heads
+        self.rotemb_attrs["partial_rotary_factor"] = 0.5  # See the self.rotary_pos_emb declaration (line 755 of modeling_chatglm.py)
+        self.rotemb_attrs["rotary_embedding_dim"] = int(self.head_size * self.rotemb_attrs["partial_rotary_factor"])
+        self.rotemb_attrs["interleaved"] = 1
+        self.attention_attrs["use_rotemb_in_attn"] = True
+        self.attention_attrs["use_packed_matmul"] = True

Review comment: The combination of these attention settings determines the attention subgraph. The subgraph for FP32 CUDA should look something like this:

    # MultiHeadAttention example:
    #
    #               root_input
    #              /    |     \
    #      Q_MatMul  K_MatMul  V_MatMul   4D causal mask  past_key  past_value
    #          |        |         |             |            |          |
    #      Q_Rotary  K_Rotary     |             +------------+----------+
    #          |        |         |                          |
    #          |     Repeat_K  Repeat_V                      |
    #           \       |       /                            |
    #            MultiHeadAttention--------------------------+
    #                    |
    #                 O_MatMul
    #                    |
    #                  O_Add

Review comment: This relates to #880; for our use case, we would like ...

+    def make_rotary_embedding(self, rotemb, name, root_input, **kwargs):
+        super().make_rotary_embedding(rotemb, name, root_input, num_heads=self.rotemb_attrs["num_heads"], rotary_embedding_dim=self.rotemb_attrs["rotary_embedding_dim"], **kwargs)
+
+    def make_attention(self, layer_id, attention, root_input, **kwargs):
+        if self.quant_type is None:
+            super().make_attention_unpacked(layer_id, attention, root_input, **kwargs)
+        # Add dummy rotary_emb attribute
+        attention.rotary_emb = type("RotaryEmbedding", (object,), {'content': {}})()
+        return super().make_attention(layer_id, attention, root_input, **kwargs)
+
+    def make_mlp_proj(self, layer_id, mlp, root_input):
+        if self.quant_type is None:
+            super().make_mlp_unpacked(layer_id, mlp, root_input)
+        super().make_mlp_proj(layer_id, mlp, root_input)
+
+    def make_layer(self, layer_id, layer):
+        layer.self_attn = layer.self_attn if hasattr(layer, 'self_attn') else layer.self_attention
+        super().make_layer(layer_id, layer)


 def check_extra_options(kv_pairs):
     if "use_8bits_moe" in kv_pairs:
         assert(kv_pairs["use_8bits_moe"] == "1" or kv_pairs["use_8bits_moe"] == "0"), "use_8bits_moe must be 0 or 1."
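A note on partial_rotary_factor = 0.5: only the first half of each head's channels are rotated, and the rest pass through unchanged, with interleaved = 1 meaning adjacent channel pairs rotate together. A simplified PyTorch sketch of that idea (not the ONNX RotaryEmbedding kernel; shapes and angles are placeholders):

```python
import torch

def partial_rotary(x, cos, sin, rotary_dim):
    # Rotate only the first rotary_dim channels of each head; pass the rest through.
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    # Interleaved pairing: channels (0, 1), (2, 3), ... rotate together.
    x1, x2 = x_rot[..., 0::2], x_rot[..., 1::2]
    rotated = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).flatten(-2)
    return torch.cat((rotated, x_pass), dim=-1)

head_size = 128
rotary_dim = int(head_size * 0.5)      # partial_rotary_factor = 0.5 -> 64
x = torch.randn(1, 32, 1, head_size)   # (batch, heads, seq, head_size)
cos = torch.ones(rotary_dim // 2)      # placeholder angles, just for shape checking
sin = torch.zeros(rotary_dim // 2)
out = partial_rotary(x, cos, sin, rotary_dim)
assert out.shape == x.shape
```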
@@ -2682,6 +2735,10 @@
         onnx_model = Phi3VModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
     elif config.architectures[0] == "Qwen2ForCausalLM":
         onnx_model = QwenModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
+    elif config.architectures[0] == "ChatGLMForConditionalGeneration" or config.architectures[0] == "ChatGLMModel":
+        # The quantized ChatGLM model reports ChatGLMForConditionalGeneration as its architecture, whereas the HF model reports ChatGLMModel
+        config.hidden_act = "swiglu"
+        onnx_model = ChatGLMModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
     else:
         raise NotImplementedError(f"The {hf_name} model is not currently supported.")

Review comment: @snnn, could you please help check the copyright?
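For completeness, a hypothetical end-to-end invocation of the builder against ChatGLM3. The -m/-o/-p/-e flags mirror the builder's usual CLI; the model id and paths are illustrative, so verify against the builder's --help before relying on this:

```python
# Hypothetical usage sketch: drive the model builder from Python.
import subprocess

subprocess.run(
    [
        "python3", "-m", "onnxruntime_genai.models.builder",
        "-m", "THUDM/chatglm3-6b",  # HF id; reports the ChatGLMModel architecture
        "-o", "./chatglm3-onnx",    # output directory for the generated ONNX model
        "-p", "fp16",               # precision
        "-e", "cuda",               # execution provider
    ],
    check=True,
)
```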