Added Support for Vanilla and Quantized ChatGLM3 Models to Model Builder #921

Open. Wants to merge 17 commits into main.
97 changes: 77 additions & 20 deletions src/python/py/models/builder.py
@@ -3,6 +3,7 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# Modifications Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved
Member: @snnn, could you please help check the copyright?

"""
Run this script to create the desired ONNX model.
"""
@@ -21,16 +22,16 @@


class Model:
def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
self.context_length = config.max_position_embeddings
self.original_context_length = config.original_max_position_embeddings if hasattr(config, "original_max_position_embeddings") else config.rope_scaling["original_max_position_embeddings"] if hasattr(config, "rope_scaling") and hasattr(config.rope_scaling, "original_max_position_embeddings") else config.max_position_embeddings
def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
self.context_length = config.seq_length if hasattr(config, "seq_length") else config.max_position_embeddings
self.original_context_length = config.original_max_position_embeddings if hasattr(config, "original_max_position_embeddings") else config.rope_scaling["original_max_position_embeddings"] if hasattr(config, "rope_scaling") and hasattr(config.rope_scaling, "original_max_position_embeddings") else self.context_length
self.window_size = config.sliding_window if hasattr(config, "sliding_window") else -1 # default is -1 in GroupQueryAttention kernel
self.intermediate_size = config.intermediate_size
self.intermediate_size = config.ffn_hidden_size if hasattr(config, "ffn_hidden_size") else config.intermediate_size
self.hidden_size = config.hidden_size
self.num_kv_heads = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.multi_query_group_num if hasattr(config, "multi_query_group_num") else config.num_attention_heads
self.num_attn_heads = config.num_attention_heads
self.head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
self.num_layers = int(extra_options["num_hidden_layers"]) if "num_hidden_layers" in extra_options else config.num_hidden_layers
self.num_layers = int(extra_options["num_hidden_layers"]) if "num_hidden_layers" in extra_options else config.num_hidden_layers if hasattr(config, "num_hidden_layers") else config.num_layers
self.vocab_size = config.vocab_size
self.activation = config.hidden_activation if hasattr(config, "hidden_activation") and config.hidden_activation is not None else config.hidden_act
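For reference, a minimal sketch of how these fallbacks resolve for a ChatGLM3-style config. The field names come from the hasattr checks above; the concrete values are illustrative assumptions, not taken from this PR:

from types import SimpleNamespace

# Illustrative ChatGLM3-style config; values are assumptions for demonstration only.
config = SimpleNamespace(
    seq_length=8192,            # used instead of max_position_embeddings
    ffn_hidden_size=13696,      # used instead of intermediate_size
    multi_query_group_num=2,    # used instead of num_key_value_heads
    num_layers=28,              # used instead of num_hidden_layers
    hidden_size=4096,
    num_attention_heads=32,
)

context_length = config.seq_length if hasattr(config, "seq_length") else config.max_position_embeddings
intermediate_size = config.ffn_hidden_size if hasattr(config, "ffn_hidden_size") else config.intermediate_size
num_kv_heads = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.multi_query_group_num if hasattr(config, "multi_query_group_num") else config.num_attention_heads
num_layers = config.num_hidden_layers if hasattr(config, "num_hidden_layers") else config.num_layers
print(context_length, intermediate_size, num_kv_heads, num_layers)  # 8192 13696 2 28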

@@ -1432,14 +1433,21 @@
raise NotImplementedError(f"The MLP layer type is not set.")

def make_mlp_unpacked(self, layer_id, mlp, root_input):
packed_proj = getattr(mlp, "gate_up_proj", None) or getattr(
mlp, "dense_h_to_4h", None
)

Contributor: Can these definitions be made in one line?
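One possible single-line form (a sketch only, intended to keep the same fallback behavior as the two-call expression above):

packed_proj = getattr(mlp, "gate_up_proj", None) or getattr(mlp, "dense_h_to_4h", None)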
mlp.gate_proj = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size)
mlp.gate_proj.weight = torch.nn.Parameter(mlp.gate_up_proj.weight[ : self.intermediate_size, :])
mlp.gate_proj.weight = torch.nn.Parameter(
packed_proj.weight[: self.intermediate_size, :]
)

mlp.up_proj = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size)
mlp.up_proj.weight = torch.nn.Parameter(mlp.gate_up_proj.weight[self.intermediate_size :, :])
mlp.up_proj.weight = torch.nn.Parameter(
packed_proj.weight[self.intermediate_size :, :]
)

# Delete original packed weights
del mlp.gate_up_proj
del packed_proj

def make_mlp_proj(self, layer_id, mlp, root_input):
# Make nodes for the MLP subgraph
@@ -1450,7 +1458,7 @@
#    \        |
#     \    ActFunc
#      \      /
#         Mul
#          |
#   DownProjMatMul

Code scanning / CodeQL warning: unnecessary delete statement in function (unnecessary deletion of local variable packed_proj in function make_mlp_unpacked).

@@ -1469,8 +1477,11 @@
self.make_mul(mul_name, mul_inputs, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size])

# Make output MatMul node
down_proj = getattr(mlp, "down_proj", None) or getattr(
mlp, "dense_4h_to_h", None
)

Contributor: Can this definition be made in one line?
down_basename = f"/model/layers.{layer_id}/mlp/down_proj/MatMul"
down_name = self.make_matmul(mlp.down_proj, down_basename, f"{mul_name}/output_0")
down_name = self.make_matmul(down_proj, down_basename, f"{mul_name}/output_0")

# Assign output 0 of previous MatMul as skip input to next SkipLayerNorm
self.layernorm_attrs["skip_input"] = f"{down_name}/output_0"
@@ -1664,7 +1675,7 @@
return gelu_name

def make_activation(self, layer_id, root_input):
if self.activation in {"silu", "swish"}:
if self.activation in {"silu", "swish", "swiglu"}:
output_name = self.make_activation_with_mul(layer_id, root_input, activation="Sigmoid", domain=None)
elif self.activation in {"gelu_new", "gelu_fast", "gelu_pytorch_tanh"}:
output_name = self.make_gelu(layer_id, root_input, activation="FastGelu")
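As a side note on why "swiglu" can share the SiLU path: SwiGLU multiplies a SiLU-activated gate projection with the up projection, and SiLU(x) = x * sigmoid(x), which is the same Sigmoid-plus-Mul pattern emitted by make_activation_with_mul for "silu" and "swish". A small illustrative sketch, not builder code:

import torch

def swiglu(gate, up):
    # SwiGLU(gate, up) = SiLU(gate) * up, where SiLU(x) = x * sigmoid(x).
    return torch.nn.functional.silu(gate) * up

gate = torch.randn(2, 8)
up = torch.randn(2, 8)
out = swiglu(gate, up)  # same shape as the inputs: (2, 8)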
@@ -1744,7 +1755,17 @@
from onnxruntime_genai.models.quantized_model import QuantModel
q_size = self.num_attn_heads * self.head_size
kv_size = self.num_kv_heads * self.head_size
model = QuantModel.from_pretrained(self.quant_type, input_path, self.quant_attrs["bits"], self.quant_attrs["group_size"], self.quant_attrs["use_g_idx"], q_size, kv_size, self.intermediate_size, self.num_layers)
model = QuantModel.from_pretrained(
self.quant_type,
input_path,
self.quant_attrs["bits"],
self.quant_attrs["group_size"],
self.quant_attrs["use_g_idx"],
q_size,
kv_size,
self.intermediate_size,
self.num_layers,
)
else:
# Load PyTorch model
extra_kwargs = {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {}
@@ -1753,6 +1774,7 @@
# Loop through model and map each module to ONNX/ORT ops
self.layer_id = 0
for module in model.modules():

if isinstance(module, torch.nn.Embedding) or (hasattr(model, "embedding") and module == model.embedding):
# Checks (Hugging Face logic) or (GGUF logic)
if not self.exclude_embeds:
@@ -1764,7 +1786,7 @@
self.layernorm_attrs["root_input"] = "inputs_embeds"
self.layernorm_attrs["skip_input"] = "inputs_embeds"

elif module.__class__.__name__.endswith("DecoderLayer") and self.layer_id < self.num_layers:
elif (module.__class__.__name__.endswith("DecoderLayer") or module.__class__.__name__.endswith("GLMBlock")) and self.layer_id < self.num_layers:
# Each decoder layer of model
print(f"Reading decoder layer {self.layer_id}")
self.make_layer(self.layer_id, module)
@@ -1774,7 +1796,7 @@
# SkipLayerNorm after last decoder layer (MatMul --> SkipLayerNorm)
print("Reading final norm")
self.make_layernorm(self.layer_id, module, skip=True, simple=self.layernorm_attrs["simple"], location="final_norm")

Contributor: The extra line space between the conditions helps make the different cases more readable. Can this line deletion be reverted?

elif (isinstance(module, torch.nn.Linear) and module.out_features == self.vocab_size) or (hasattr(model, "lm_head") and module == model.lm_head):
# Checks (Hugging Face logic) or (GGUF logic)
if not self.exclude_lm_head:
@@ -1785,12 +1807,13 @@
del model

def has_final_norm(self, module, model):
# Hugging Face names
hf_norm = hasattr(model, "model") and hasattr(model.model, "norm") and module == model.model.norm
hf_final_layernorm = hasattr(model, "model") and hasattr(model.model, "final_layernorm") and module == model.model.final_layernorm
# GGUF names
gguf_final_norm = hasattr(model, "final_norm") and module == model.final_norm
return hf_norm or hf_final_layernorm or gguf_final_norm
# Hugging Face names
hf_norm = hasattr(model, "model") and hasattr(model.model, "norm") and module == model.model.norm
hf_final_layernorm = hasattr(model, "model") and hasattr(model.model, "final_layernorm") and module == model.model.final_layernorm
hf_transformer_final_layernorm = hasattr(model, "transformer") and hasattr(model.transformer, "encoder") and hasattr(model.transformer.encoder, "final_layernorm") and module == model.transformer.encoder.final_layernorm
# GGUF names
gguf_final_norm = hasattr(model, "final_norm") and module == model.final_norm
return hf_norm or hf_final_layernorm or hf_transformer_final_layernorm or gguf_final_norm

def make_preprocessing_nodes(self):
self.make_attention_mask_reformatting()
@@ -2613,6 +2636,36 @@
self.layernorm_attrs["last_layernorm"] = True


class ChatGLMModel(Model):
def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
self.rotemb_attrs["num_heads"] = self.num_attn_heads
self.rotemb_attrs["partial_rotary_factor"] = 0.5 # Line 755 of modeling_chatglm.py check self.rotary_pos_emb declaration
self.rotemb_attrs["rotary_embedding_dim"] = int(self.head_size * self.rotemb_attrs["partial_rotary_factor"])
self.rotemb_attrs["interleaved"] = 1
self.attention_attrs["use_rotemb_in_attn"] = True
self.attention_attrs["use_packed_matmul"] = True
Contributor: The combination of these attention_attrs settings will cause the FP32 CUDA model to be created incorrectly. For the FP32 CUDA model, the GroupQueryAttention op does not have an implementation. The fallback is to use the MultiHeadAttention op, which does not fuse rotary embeddings within it and does not handle the extra repeat_kv operation needed when num_attention_heads != num_key_value_heads. Therefore, the FP32 CUDA model would need the following settings.

  • use_rotemb_in_attn = False
  • use_packed_matmul = False
  • op_type = "MultiHeadAttention"

The subgraph for FP32 CUDA should look something like this.

# MultiHeadAttention example:
#
#               root_input
#              /     |     \
#       Q_MatMul  K_MatMul  V_MatMul  4D causal mask  past_key  past_value
#           |        |         |            |            |           |
#       Q_Rotary  K_Rotary     |            +------------+-----------+
#           |        |         |                         |
#           |     Repeat_K  Repeat_V                     |
#           \        |        /                          |
#            MultiHeadAttention--------------------------+
#                    |
#                 O_MatMul
#                    |
#                  O_Add
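A minimal sketch of how such an override might look inside ChatGLMModel.__init__, assuming the attention_attrs keys named above and that the builder exposes self.ep and self.io_dtype as it does elsewhere in builder.py; the condition below is an illustration, not code from this PR:

# Hypothetical FP32 CUDA fallback, since GroupQueryAttention has no FP32 CUDA kernel.
if self.ep == "cuda" and self.io_dtype == TensorProto.FLOAT:
    self.attention_attrs["op_type"] = "MultiHeadAttention"
    self.attention_attrs["use_rotemb_in_attn"] = False
    self.attention_attrs["use_packed_matmul"] = False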

Contributor: This relates to #880. For our use case, we would like GroupQueryAttention, and our custom EP would have an implementation for it. I'm OK with changing the default behavior to adhere to what is available in CPU/CUDA, as long as there is a way to override and generate with selected ops.

Contributor: The self.attention_attrs["use_rotemb_in_attn"] = True and self.attention_attrs["use_packed_matmul"] = True lines can be deleted, since these will be set automatically when the attention op is set.


def make_rotary_embedding(self, rotemb, name, root_input, **kwargs):
super().make_rotary_embedding(rotemb, name, root_input, num_heads=self.rotemb_attrs["num_heads"], rotary_embedding_dim=self.rotemb_attrs["rotary_embedding_dim"], **kwargs)

def make_attention(self, layer_id, attention, root_input, **kwargs):
if self.quant_type is None:
super().make_attention_unpacked(layer_id, attention, root_input, **kwargs)
# Add dummy rotary_emb attribute
attention.rotary_emb = type("RotaryEmbedding", (object,), {'content':{}})()
return super().make_attention(layer_id, attention, root_input, **kwargs)
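For readability, the dynamically created dummy above is equivalent to the following small sketch (illustrative only):

class RotaryEmbedding:
    # Same as type("RotaryEmbedding", (object,), {'content': {}}): a class attribute holding an empty dict.
    content = {}

attention.rotary_emb = RotaryEmbedding()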


def make_mlp_proj(self, layer_id, mlp, root_input):
if self.quant_type is None:
super().make_mlp_unpacked(layer_id, mlp, root_input)
super().make_mlp_proj(layer_id, mlp, root_input)

def make_layer(self, layer_id, layer):
layer.self_attn = layer.self_attn if hasattr(layer, 'self_attn') else layer.self_attention
super().make_layer(layer_id, layer)

def check_extra_options(kv_pairs):
if "use_8bits_moe" in kv_pairs:
assert(kv_pairs["use_8bits_moe"] == "1" or kv_pairs["use_8bits_moe"] == "0"), "use_8bits_moe must be 0 or 1."
@@ -2682,6 +2735,10 @@
onnx_model = Phi3VModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
elif config.architectures[0] == "Qwen2ForCausalLM":
onnx_model = QwenModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
elif config.architectures[0] == "ChatGLMForConditionalGeneration" or config.architectures[0] == "ChatGLMModel":
# The quantized ChatGLM model reports ChatGLMForConditionalGeneration as its architecture, whereas the Hugging Face model reports ChatGLMModel
config.hidden_act = "swiglu"
onnx_model = ChatGLMModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
else:
raise NotImplementedError(f"The {hf_name} model is not currently supported.")
