From 6c260b9075dca80d6fc436ab2f76974a00e0b700 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 19 Nov 2025 01:01:36 +0000 Subject: [PATCH 01/18] add builder for Qwen2_5_VLTextModel --- src/python/py/models/builder.py | 9 +- src/python/py/models/builders/__init__.py | 4 +- src/python/py/models/builders/base.py | 8 +- src/python/py/models/builders/phi.py | 1 - src/python/py/models/builders/qwen.py | 605 +++++++++++++++++- test/python/models/qwen_2.5_vl/run.sh | 59 ++ .../models/qwen_2.5_vl/test_qwen_2.5_vl.py | 451 +++++++++++++ 7 files changed, 1128 insertions(+), 9 deletions(-) create mode 100644 test/python/models/qwen_2.5_vl/run.sh create mode 100644 test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 2d802df43d..3f9e982b49 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -39,6 +39,7 @@ Phi3VModel, Phi4MMModel, PhiModel, + Qwen25VLTextModel, Qwen3Model, QwenModel, SmolLM3Model, @@ -292,6 +293,12 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid onnx_model = Qwen3Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options) elif config.architectures[0] == "SmolLM3ForCausalLM": onnx_model = SmolLM3Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options) + elif config.architectures[0] == "Qwen2_5_VLForConditionalGeneration": + print( + "WARNING: This is only generating the text component of the model. Setting `--extra_options exclude_embeds=true` by default." + ) + extra_options["exclude_embeds"] = True + onnx_model = Qwen25VLTextModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options) elif config_only: # Create base Model class to guess model attributes onnx_model = Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options) @@ -300,7 +307,7 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid if not config_only: # Make ONNX model - onnx_model.make_model(input_path) + onnx_model.make_model(input_path, config) # Save ONNX model onnx_model.save_model(output_dir) diff --git a/src/python/py/models/builders/__init__.py b/src/python/py/models/builders/__init__.py index 5e0a0ccfbb..6f606fd5ae 100644 --- a/src/python/py/models/builders/__init__.py +++ b/src/python/py/models/builders/__init__.py @@ -6,7 +6,7 @@ from .base import Model from .llama import LlamaModel from .mistral import MistralModel -from .qwen import QwenModel, Qwen3Model +from .qwen import QwenModel, Qwen3Model, Qwen25VLTextModel from .phi import ( PhiModel, Phi3MiniModel, Phi3MiniLongRoPEModel, Phi3SmallModel, Phi3SmallLongRoPEModel, Phi3VModel, Phi3MoELongRoPEModel, Phi4MMModel @@ -22,7 +22,7 @@ __all__ = [ "Model", - "LlamaModel", "MistralModel", "QwenModel", "Qwen3Model", "PhiModel", + "LlamaModel", "MistralModel", "QwenModel", "Qwen3Model", "Qwen25VLTextModel", "PhiModel", "Phi3MiniModel", "Phi3MiniLongRoPEModel", "Phi3SmallModel", "Phi3SmallLongRoPEModel", "Phi3VModel", "Phi3MoELongRoPEModel", "Phi4MMModel", "GemmaModel", "Gemma2Model", "Gemma3Model", "NemotronModel", "ChatGLMModel", diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index d83326a53c..b856b2eb2f 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -360,6 +360,11 @@ def make_rope_init(self, config): "ntk_alpha": beta_slow, "ntk_beta": beta_fast, } + elif "mrope_section" in config.rope_scaling: + # For models that use MRoPE (e.g. Qwen 2.5 VL) + self.rope_attrs["mrope"] = { + "sections": config.rope_scaling["mrope_section"], # Sections for MRoPE + } def make_attention_init(self): valid_gqa_configurations = { @@ -2989,7 +2994,6 @@ def make_model(self, input_path): q_size = self.num_attn_heads * self.head_size kv_size = self.num_kv_heads * self.head_size model = QuantModel.from_pretrained(self.quant_type, input_path=input_path, quant_attrs=self.quant_attrs, q_size=q_size, kv_size=kv_size, intermediate_size=self.intermediate_size, num_layers=self.num_layers) - else: # Load PyTorch model extra_kwargs = {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {} @@ -3515,4 +3519,4 @@ def make_attention_mask_reformatting_for_sparse_attn(self): def make_position_ids_reformatting(self): # For most cases, position_ids are already properly formatted as 2D tensors # with int64 values matching input_ids shape, so we can use them directly - return "position_ids" \ No newline at end of file + return "position_ids" diff --git a/src/python/py/models/builders/phi.py b/src/python/py/models/builders/phi.py index ae13ad2745..8aa8eae8ab 100644 --- a/src/python/py/models/builders/phi.py +++ b/src/python/py/models/builders/phi.py @@ -355,7 +355,6 @@ def make_layer(self, layer_id, layer): # Norm after last decoder layer of model (last layer --> norm) self.layernorm_attrs["last_layernorm"] = True - class Phi4MMModel(Phi3VModel): def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index 574fd1ba25..baba2bcffa 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -3,9 +3,13 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from .mistral import MistralModel -class QwenModel(MistralModel): +import os +from .base import Model # Changed this to match your new inheritance +import onnx_ir as ir +import torch + +class QwenModel(Model): # Changed this to match your new inheritance def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) @@ -16,4 +20,599 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): def make_attention_init(self): self.attention_attrs["q_norm"] = True self.attention_attrs["k_norm"] = True - super().make_attention_init() \ No newline at end of file + super().make_attention_init() + +class Qwen25VLTextModel(QwenModel): + def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): + # We must extract the text_config for the text model's parameters + text_config_dict = config.text_config.to_dict() + + # Update the main config with text-specific parameters + # The base.Model class reads from the top-level config object + config.hidden_size = text_config_dict["hidden_size"] + config.intermediate_size = text_config_dict["intermediate_size"] + config.num_attention_heads = text_config_dict["num_attention_heads"] + config.num_hidden_layers = text_config_dict["num_hidden_layers"] + config.num_key_value_heads = text_config_dict["num_key_value_heads"] + config.rms_norm_eps = text_config_dict["rms_norm_eps"] + config.sliding_window = text_config_dict["sliding_window"] + config.rope_scaling = text_config_dict["rope_scaling"] + # Need this for attention_scaling calculation + if "original_max_position_embeddings" in text_config_dict: + config.original_max_position_embeddings = text_config_dict["original_max_position_embeddings"] + + super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) + + # The HF model (Qwen2RMSNorm) *always* computes LayerNorm in float32. + # By inheriting from `base.Model`, all `layernorm_attrs["cast"]` flags + # are `False`. This causes two problems: + # 1. Parity Error (FP32 model): The 47% mismatch you saw. + # 2. Type Mismatch Error (BF16 model): The `(float)` vs `(bfloat16)` error. + # + # SOLUTION: Manually set all `cast` flags to `True`. This forces the + # builder to cast bf16 inputs -> fp32, compute LN, and cast fp32 + # outputs -> bf16, matching the HF model and fixing both errors. + # + print("Forcing LayerNorm computation to float32 (and enabling all casts) for Qwen2.5-VL parity.") + self.layernorm_attrs["cast"]["use_fp32"] = True + self.layernorm_attrs["cast"]["root_input"] = True + self.layernorm_attrs["cast"]["skip_input"] = True + self.layernorm_attrs["cast"]["output_0"] = True + self.layernorm_attrs["cast"]["output_3"] = True + # + # Qwen2's RoPE *always* computes in float32. + # We must replicate this behavior. + print("Forcing RoPE computation to float32 for Qwen2.5-VL parity.") + if "rope_cast" not in self.attention_attrs: + self.attention_attrs["rope_cast"] = {} + self.attention_attrs["rope_cast"]["use_fp32"] = True + + # The base.Model.make_outputs_init() *always* casts logits to float32 + # if the io_dtype is bfloat16. This is to improve accuracy in general. + # + # PROBLEM: The HF model (Qwen2_5_VL) *does not* do this. It computes + # the lm_head MatMul in bfloat16 and returns bfloat16 logits. + # This causes the parity test (which compares bf16 vs fp32) to fail. + # + # SOLUTION: We must override the base model's decision and set the + # output logits type to match the io_dtype (bfloat16). + # + self.allow_bf16_logits = os.getenv("allow_bf16_logits") in ["1", "true", "True"] + if self.allow_bf16_logits and self.io_dtype == ir.DataType.BFLOAT16: + print("Fixing output logits precision. Setting output_types['logits'] to BFLOAT16 to match HF model.") + self.output_types["logits"] = ir.DataType.BFLOAT16 + + # Manually get the attention_scaling from the rope_config + # This replicates the logic from transformers.models.rope_utils._config_to_init_values + rope_type = "default" + if config.rope_scaling and "type" in config.rope_scaling: + # The config re-maps 'mrope' to 'default' + if config.rope_scaling["type"] != "mrope": + rope_type = config.rope_scaling["type"] + + if rope_type == "yarn": + factor = config.rope_scaling.get("factor", 1.0) + self.rope_attrs["attention_scaling"] = config.rope_scaling.get("attention_factor", (0.1 * torch.log(torch.tensor(factor)) + 1.0).item()) + elif rope_type == "longrope": + factor = config.rope_scaling.get("factor", 1.0) + orig_max_pos = config.original_max_position_embeddings + self.rope_attrs["attention_scaling"] = config.rope_scaling.get("attention_factor", torch.sqrt(1 + torch.log(torch.tensor(factor)) / torch.log(torch.tensor(orig_max_pos))).item()) + else: + self.rope_attrs["attention_scaling"] = 1.0 + + # Qwen 2.5 VL applies RoPE manually before attention, not fused in the op + self.attention_attrs["use_rope_in_attn"] = False + + # Your inheritance change fixed this, but this check is harmless and safe. + if "position_ids" not in self.input_names: + print("Re-adding 'position_ids' to self.input_names.") + if "attention_mask" in self.input_names: + idx = self.input_names.index("attention_mask") + self.input_names.insert(idx + 1, "position_ids") + else: + self.input_names.append("position_ids") + + self.mrope_sections = self.rope_attrs.get("mrope", {}).get("sections", []) + if not self.mrope_sections: + raise ValueError("MRoPE sections not found in config.text_config.rope_scaling.mrope_section") + + # The HF logic is `mrope_section * 2`, not `[s * 2 for s in mrope_section]`. + # This results in [16, 24, 24, 16, 24, 24] + self.mrope_splits = self.mrope_sections * 2 + + if sum(self.mrope_splits) != self.head_size: + # The sum (128) should now correctly match self.head_size (128) + raise ValueError(f"MRoPE splits {self.mrope_splits} sum ({sum(self.mrope_splits)}) does not match head size ({self.head_size})") + + # Force GroupQueryAttention for fp32 cuda, + # as base.py's make_attention_init doesn't include this combo. + if self.ep == "cuda" and self.io_dtype == ir.DataType.FLOAT: + self.attention_attrs["op_type"] = "GroupQueryAttention" + print("Forcing GroupQueryAttention (GQA) for FP32 CUDA.") + + if self.attention_attrs["op_type"] != "GroupQueryAttention": + raise ValueError(f"Qwen2.5-VL requires GroupQueryAttention, but op_type is {self.attention_attrs['op_type']}. This may be due to an unsupported EP/precision combo.") + + # Create and save the inv_freq tensor + self.make_inv_freq_tensor() + + def make_inv_freq_tensor(self): + """ + Calculates and saves the `inv_freq` tensor as an initializer. + This is copied from base.py:make_rotary_embedding_caches_from_scratch + """ + dim = int(self.rope_attrs["partial_rotary_factor"] * self.head_size) + inv_freq = 1.0 / (self.rope_attrs["rescale_factors"] * (self.rope_attrs["theta"] ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))) + + # The HF model expects H/2, not R/2 + if dim != self.head_size: + print(f"Warning: partial_rotary_factor ({self.rope_attrs['partial_rotary_factor']}) is not 1. This might be unsupported.") + inv_freq = inv_freq[:(self.head_size // 2)] + + self.make_initializer(inv_freq, "model.inv_freq", to=ir.DataType.FLOAT) + print("Created and saved 'model.inv_freq' initializer.") + + + def make_inputs_and_outputs(self): + # Qwen2.5-VL uses 3D position_ids + self.input_shapes["position_ids"] = [3, "batch_size", "sequence_length"] + + # Call the base Model's make_inputs_and_outputs (skipping MistralModel's) + super(QwenModel, self).make_inputs_and_outputs() + + def make_dynamic_rope_caches(self, layer_id, basename): + """ + Re-implements Qwen2_5_VLRotaryEmbedding.forward using ONNX ops. + Takes 3D position_ids and inv_freq and dynamically creates + the cos/sin caches. + """ + pos_ids_name = "position_ids" + inv_freq_name = "model.inv_freq" + head_dim_half = self.head_size // 2 + + # Get Batch Size from position_ids.shape[1] + shape_pos_ids_name = f"{basename}/shape_pos_ids" + shape_pos_ids_output = f"{shape_pos_ids_name}/output_0" + self.make_shape(shape_pos_ids_name, pos_ids_name, [3]) + + gather_batch_size_name = f"{basename}/gather_batch_size" + gather_batch_size_output = f"{gather_batch_size_name}/output_0" + self.make_gather(gather_batch_size_name, [shape_pos_ids_output, "/model/constants/INT64/[1]"], ir.DataType.INT64, [1], axis=0) + + # Expand inv_freq: [H/2] -> [1, 1, H/2, 1] + unsqueeze_1_name = f"{basename}/inv_freq_unsqueeze_1" + unsqueeze_1_output = f"{unsqueeze_1_name}/output_0" + self.make_unsqueeze(unsqueeze_1_name, [inv_freq_name, "/model/constants/INT64/[0, 1, 3]"], ir.DataType.FLOAT, [1, 1, head_dim_half, 1]) + + # Create target shape for Expand: [3, B, H/2, 1] + concat_expand_shape_name = f"{basename}/concat_expand_shape" + concat_expand_shape_output = f"{concat_expand_shape_name}/output_0" + self.make_concat( + concat_expand_shape_name, + ["/model/constants/INT64/[3]", gather_batch_size_output, f"/model/constants/INT64/[{head_dim_half}, 1]"], + ir.DataType.INT64, + [4], + axis=0 + ) + + expand_name = f"{basename}/inv_freq_expand" + expand_output = f"{expand_name}/output_0" + self.make_expand(expand_name, [unsqueeze_1_output, concat_expand_shape_output], ir.DataType.FLOAT, [3, "batch_size", head_dim_half, 1]) + + # Expand position_ids: [3, B, S] -> [3, B, 1, S] + unsqueeze_2_name = f"{basename}/pos_ids_unsqueeze" + unsqueeze_2_output = f"{unsqueeze_2_name}/output_0" + self.make_unsqueeze(unsqueeze_2_name, [pos_ids_name, "/model/constants/INT64/[2]"], ir.DataType.INT64, [3, "batch_size", 1, "sequence_length"]) + + # Cast position_ids to float + cast_name = f"{basename}/pos_ids_cast" + cast_output = f"{cast_name}/output_0" + self.make_cast(cast_name, unsqueeze_2_output, ir.DataType.FLOAT, [3, "batch_size", 1, "sequence_length"]) + + # MatMul: [3, B, H/2, 1] @ [3, B, 1, S] -> [3, B, H/2, S] + matmul_name = f"{basename}/freqs_matmul" + matmul_output = f"{matmul_name}/output_0" + self.make_node("MatMul", [expand_output, cast_output], [matmul_output], name=matmul_name) + self.make_value(matmul_output, ir.DataType.FLOAT, [3, "batch_size", head_dim_half, "sequence_length"]) + + # Transpose: [3, B, H/2, S] -> [3, B, S, H/2] + transpose_name = f"{basename}/freqs_transpose" + transpose_output = f"{transpose_name}/output_0" + self.make_transpose(transpose_name, matmul_output, ir.DataType.FLOAT, [3, "batch_size", "sequence_length", head_dim_half], perm=[0, 1, 3, 2]) + + # Concat (freqs, freqs): [3, B, S, H/2] -> [3, B, S, H] + concat_name = f"{basename}/emb_concat" + concat_output = f"{concat_name}/output_0" + self.make_concat(concat_name, [transpose_output, transpose_output], ir.DataType.FLOAT, [3, "batch_size", "sequence_length", self.head_size], axis=-1) + + # Cos(emb) and Sin(emb) + cos_name = f"{basename}/cos" + cos_output = f"{cos_name}/output_0" + self.make_node("Cos", [concat_output], [cos_output], name=cos_name) + self.make_value(cos_output, ir.DataType.FLOAT, [3, "batch_size", "sequence_length", self.head_size]) + + sin_name = f"{basename}/sin" + sin_output = f"{sin_name}/output_0" + self.make_node("Sin", [concat_output], [sin_output], name=sin_name) + self.make_value(sin_output, ir.DataType.FLOAT, [3, "batch_size", "sequence_length", self.head_size]) + + # Apply attention_scaling + cos_final_output = cos_output + sin_final_output = sin_output + scale = self.rope_attrs.get("attention_scaling", 1.0) # Get from rope_attrs + + if scale != 1.0: + scale_const_name = f"/model/constants/FLOAT/{scale}" + + cos_mul_name = f"{basename}/cos_mul_scale" + cos_final_output = f"{cos_mul_name}/output_0" + self.make_node("Mul", [cos_output, scale_const_name], [cos_final_output], name=cos_mul_name) + self.make_value(cos_final_output, ir.DataType.FLOAT, [3, "batch_size", "sequence_length", self.head_size]) + + sin_mul_name = f"{basename}/sin_mul_scale" + sin_final_output = f"{sin_mul_name}/output_0" + self.make_node("Mul", [sin_output, scale_const_name], [sin_final_output], name=sin_mul_name) + self.make_value(sin_final_output, ir.DataType.FLOAT, [3, "batch_size", "sequence_length", self.head_size]) + + return cos_final_output, sin_final_output + + def rotate_half(self, x_name, x_shape, basename, compute_dtype): + """ + Builds ONNX nodes for rotate_half(x) + x_shape is [B, N, S, H] + """ + # Split: [B, N, S, H] -> [B, N, S, H/2], [B, N, S, H/2] + split_name = f"{basename}/rotate_half/Split" + split_output_0 = f"{split_name}/output_0" + split_output_1 = f"{split_name}/output_1" + self.make_node("Split", [x_name], [split_output_0, split_output_1], name=split_name, axis=-1, num_outputs=2) + half_shape = x_shape[:-1] + [x_shape[-1] // 2] + self.make_value(split_output_0, compute_dtype, half_shape) + self.make_value(split_output_1, compute_dtype, half_shape) + + # Negate x2 + neg_name = f"{basename}/rotate_half/Neg" + neg_output = f"{neg_name}/output_0" + self.make_node("Neg", [split_output_1], [neg_output], name=neg_name) + self.make_value(neg_output, compute_dtype, half_shape) + + # Concat (-x2, x1) + concat_name = f"{basename}/rotate_half/Concat" + concat_output = f"{concat_name}/output_0" + self.make_concat(concat_name, [neg_output, split_output_0], compute_dtype, x_shape, axis=-1) + + return concat_output + + def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn_sin, num_heads, basename): + """ + Re-implements apply_multimodal_rotary_pos_emb using ONNX ops. + Takes Q/K tensor and the dynamically generated 3D caches + and applies the rotation. + """ + + # --- Handle precision for RoPE --- + # Check if we need to force float32 computation + force_fp32 = self.attention_attrs.get("rope_cast", {}).get("use_fp32", False) + + # Set compute_dtype (precision for math) and output_dtype (final precision) + compute_dtype = ir.DataType.FLOAT if force_fp32 else self.io_dtype + output_dtype = self.io_dtype + # -------------------------------- + + # Create a Constant node for mrope_splits + # This holds the correct splits, e.g., [16, 24, 24, 16, 24, 24] + mrope_splits_node_name = f"{basename}/mrope_splits_node" + mrope_splits_output_name = f"{basename}/mrope_splits" + mrope_splits_tensor = ir.tensor( + torch.tensor(self.mrope_splits, dtype=torch.int64), + name=mrope_splits_output_name + ) + self.make_node( + "Constant", + inputs=[], + outputs=[mrope_splits_output_name], + name=mrope_splits_node_name, + value=mrope_splits_tensor + ) + self.make_value(mrope_splits_output_name, ir.DataType.INT64, [len(self.mrope_splits)]) + + # Split the dynamic caches [3, B, S, H] into 6 chunks on axis -1 + # Caches (dyn_cos, dyn_sin) are already in float32 + num_splits = len(self.mrope_splits) + + cos_split_name = f"{basename}/cos_split" + cos_split_outputs = [f"{cos_split_name}/output_{i}" for i in range(num_splits)] + self.make_node("Split", [dyn_cos, mrope_splits_output_name], cos_split_outputs, name=cos_split_name, axis=-1) + + sin_split_name = f"{basename}/sin_split" + sin_split_outputs = [f"{sin_split_name}/output_{i}" for i in range(num_splits)] + self.make_node("Split", [dyn_sin, mrope_splits_output_name], sin_split_outputs, name=sin_split_name, axis=-1) + + # Re-order the caches: [T, H, W, T, H, W] + cos_reordered = [] + sin_reordered = [] + for i in range(num_splits): + dim_chunk = self.mrope_splits[i] + cache_dim_to_use = i % 3 # 0 for T, 1 for H, 2 for W + + # Gather from dim 0 of the split cache chunk + # input is [3, B, S, H_chunk], indices is [0, 1, or 2] + gather_cos_name = f"{basename}/cos_gather_{i}" + gather_cos_output = f"{gather_cos_name}/output_0" + self.make_node("Gather", [cos_split_outputs[i], f"/model/constants/INT64/[{cache_dim_to_use}]"], [gather_cos_output], name=gather_cos_name, axis=0) + self.make_value(gather_cos_output, ir.DataType.FLOAT, [1, "batch_size", "sequence_length", dim_chunk]) # Shape [1, B, S, H_chunk] + + gather_sin_name = f"{basename}/sin_gather_{i}" + gather_sin_output = f"{gather_sin_name}/output_0" + self.make_node("Gather", [sin_split_outputs[i], f"/model/constants/INT64/[{cache_dim_to_use}]"], [gather_sin_output], name=gather_sin_name, axis=0) + self.make_value(gather_sin_output, ir.DataType.FLOAT, [1, "batch_size", "sequence_length", dim_chunk]) # Shape [1, B, S, H_chunk] + + # FIX: Squeeze the gathered cache to [B, S, H_chunk] + squeeze_cos_name = f"{basename}/cos_squeeze_{i}" + squeeze_cos_output = f"{squeeze_cos_name}/output_0" + self.make_squeeze(squeeze_cos_name, [gather_cos_output, "/model/constants/INT64/[0]"], ir.DataType.FLOAT, ["batch_size", "sequence_length", dim_chunk]) + + squeeze_sin_name = f"{basename}/sin_squeeze_{i}" + squeeze_sin_output = f"{squeeze_sin_name}/output_0" + self.make_squeeze(squeeze_sin_name, [gather_sin_output, "/model/constants/INT64/[0]"], ir.DataType.FLOAT, ["batch_size", "sequence_length", dim_chunk]) + + # Unsqueeze to add the NumHeads dim: [B, 1, S, H_chunk] + unsqueeze_cos_name = f"{basename}/cos_unsqueeze_{i}" + unsqueeze_cos_output = f"{unsqueeze_cos_name}/output_0" + self.make_unsqueeze(unsqueeze_cos_name, [squeeze_cos_output, "/model/constants/INT64/[1]"], ir.DataType.FLOAT, ["batch_size", 1, "sequence_length", dim_chunk]) + cos_reordered.append(unsqueeze_cos_output) + + unsqueeze_sin_name = f"{basename}/sin_unsqueeze_{i}" + unsqueeze_sin_output = f"{unsqueeze_sin_name}/output_0" + self.make_unsqueeze(unsqueeze_sin_name, [squeeze_sin_output, "/model/constants/INT64/[1]"], ir.DataType.FLOAT, ["batch_size", 1, "sequence_length", dim_chunk]) + sin_reordered.append(unsqueeze_sin_output) + + # Concat re-ordered chunks back to [B, 1, S, H] + final_cos_concat_name = f"{basename}/cos_final_concat" + final_cos_concat_output = f"{final_cos_concat_name}/output_0" + self.make_concat(final_cos_concat_name, cos_reordered, ir.DataType.FLOAT, ["batch_size", 1, "sequence_length", self.head_size], axis=-1) + + final_sin_concat_name = f"{basename}/sin_final_concat" + final_sin_concat_output = f"{final_sin_concat_name}/output_0" + self.make_concat(final_sin_concat_name, sin_reordered, ir.DataType.FLOAT, ["batch_size", 1, "sequence_length", self.head_size], axis=-1) + + # Caches (final_cos_concat_output, final_sin_concat_output) are now in float32 + + # Reshape input Q/K: [B, S, N*H] -> [B, N, S, H] + reshape_1_name = f"{basename}/q_or_k_reshape_1" + reshape_1_output = f"{reshape_1_name}/output_0" + reshape_1_target_shape_onnx = f"/model/constants/INT64/[0, 0, {num_heads}, {self.head_size}]" + reshape_1_target_shape_ort = ["batch_size", "sequence_length", num_heads, self.head_size] + self.make_reshape(reshape_1_name, [q_or_k_path, reshape_1_target_shape_onnx], self.io_dtype, reshape_1_target_shape_ort) + + # Transpose Q/K: [B, S, N, H] -> [B, N, S, H] + transpose_1_name = f"{basename}/q_or_k_transpose_1" + transpose_1_output = f"{transpose_1_name}/output_0" + transpose_1_target_shape = ["batch_size", num_heads, "sequence_length", self.head_size] + self.make_transpose(transpose_1_name, reshape_1_output, self.io_dtype, transpose_1_target_shape, perm=[0, 2, 1, 3]) + + # --- Start RoPE computation --- + q_or_k_compute_input = transpose_1_output + cos_cache_compute_input = final_cos_concat_output + sin_cache_compute_input = final_sin_concat_output + + if force_fp32 and self.io_dtype != ir.DataType.FLOAT: + # Cast Q/K (self.io_dtype) up to float32 + q_or_k_cast_name = f"{basename}/q_or_k_cast_fp32" + q_or_k_cast_output = f"{q_or_k_cast_name}/output_0" + self.make_cast(q_or_k_cast_name, transpose_1_output, compute_dtype, transpose_1_target_shape) + q_or_k_compute_input = q_or_k_cast_output + elif not force_fp32 and self.io_dtype != ir.DataType.FLOAT: + # Cast Caches (float32) down to self.io_dtype + cos_cache_cast_name = f"{basename}/cos_final_cast" + cos_cache_cast_output = f"{cos_cache_cast_name}/output_0" + self.make_cast(cos_cache_cast_name, final_cos_concat_output, compute_dtype, ["batch_size", 1, "sequence_length", self.head_size]) + cos_cache_compute_input = cos_cache_cast_output + + sin_cache_cast_name = f"{basename}/sin_final_cast" + sin_cache_cast_output = f"{sin_cache_cast_name}/output_0" + self.make_cast(sin_cache_cast_name, final_sin_concat_output, compute_dtype, ["batch_size", 1, "sequence_length", self.head_size]) + sin_cache_compute_input = sin_cache_cast_output + + # Apply rotation: (q * cos) + (rotate_half(q) * sin) + + # 1. (q * cos) + mul_1_name = f"{basename}/mul_1" + mul_1_output = f"{mul_1_name}/output_0" + self.make_mul(mul_1_name, [q_or_k_compute_input, cos_cache_compute_input], compute_dtype, transpose_1_target_shape) + + # 2. rotate_half(q) + rotated_half_q_name = self.rotate_half(q_or_k_compute_input, transpose_1_target_shape, basename, compute_dtype) + + # 3. (rotate_half(q) * sin) + mul_2_name = f"{basename}/mul_2" + mul_2_output = f"{mul_2_name}/output_0" + self.make_mul(mul_2_name, [rotated_half_q_name, sin_cache_compute_input], compute_dtype, transpose_1_target_shape) + + # 4. (q * cos) + (rotate_half(q) * sin) + add_name = f"{basename}/add" + add_output = f"{add_name}/output_0" + self.make_add(add_name, [mul_1_output, mul_2_output], compute_dtype, transpose_1_target_shape) + + # --- End RoPE computation --- + + add_output_final = add_output + if force_fp32 and self.io_dtype != ir.DataType.FLOAT: + # Cast result back down to self.io_dtype + add_cast_name = f"{basename}/add_cast_output" + add_cast_output = f"{add_cast_name}/output_0" + self.make_cast(add_cast_name, add_output, output_dtype, transpose_1_target_shape) + add_output_final = add_cast_output + + # Transpose back: [B, N, S, H] -> [B, S, N, H] + transpose_2_name = f"{basename}/q_or_k_transpose_2" + transpose_2_output = f"{transpose_2_name}/output_0" + self.make_transpose(transpose_2_name, add_output_final, output_dtype, reshape_1_target_shape_ort, perm=[0, 2, 1, 3]) + + # Reshape back: [B, S, N, H] -> [B, S, N*H] + reshape_2_name = f"{basename}/q_or_k_reshape_2" + reshape_2_output = f"{reshape_2_name}/output_0" + self.make_reshape(reshape_2_name, [transpose_2_output, f"/model/constants/INT64/[0, 0, {num_heads * self.head_size}]"], output_dtype, q_or_k_shape) + + return reshape_2_output + + def make_attention(self, layer_id, attention, root_input, **kwargs): + + # 1. Unpack QKV if necessary (e.g. qkv_proj) + super(QwenModel, self).make_attention_unpacked(layer_id, attention, root_input, **kwargs) + + # 2. Build Q/K/V MatMul and Add nodes + q_matmul_basename = f"/model/layers.{layer_id}/attn/q_proj/MatMul" + q_matmul_name = self.make_matmul(attention.q_proj, q_matmul_basename, root_input) + self.attention_attrs["q_path"] = f"{q_matmul_name}/output_0" + q_shape = ["batch_size", "sequence_length", self.num_attn_heads * self.head_size] + + k_matmul_basename = f"/model/layers.{layer_id}/attn/k_proj/MatMul" + k_matmul_name = self.make_matmul(attention.k_proj, k_matmul_basename, root_input) + self.attention_attrs["k_path"] = f"{k_matmul_name}/output_0" + k_shape = ["batch_size", "sequence_length", self.num_kv_heads * self.head_size] + + v_matmul_basename = f"/model/layers.{layer_id}/attn/v_proj/MatMul" + v_matmul_name = self.make_matmul(attention.v_proj, v_matmul_basename, root_input) + self.attention_attrs["v_path"] = f"{v_matmul_name}/output_0" + + # Handle biases + q_bias_exists = attention.q_proj.bias is not None and torch.count_nonzero(attention.q_proj.bias) > 0 + k_bias_exists = attention.k_proj.bias is not None and torch.count_nonzero(attention.k_proj.bias) > 0 + v_bias_exists = attention.v_proj.bias is not None and torch.count_nonzero(attention.v_proj.bias) > 0 + + if q_bias_exists: + q_add_name = f"/model/layers.{layer_id}/attn/q_proj/Add" + self.make_add_bias(attention.q_proj.bias, q_add_name, root_input=self.attention_attrs["q_path"]) + self.attention_attrs["q_path"] = f"{q_add_name}/output_0" + if k_bias_exists: + k_add_name = f"/model/layers.{layer_id}/attn/k_proj/Add" + self.make_add_bias(attention.k_proj.bias, k_add_name, root_input=self.attention_attrs["k_path"]) + self.attention_attrs["k_path"] = f"{k_add_name}/output_0" + if v_bias_exists: + v_add_name = f"/model/layers.{layer_id}/attn/v_proj/Add" + self.make_add_bias(attention.v_proj.bias, v_add_name, root_input=self.attention_attrs["v_path"]) + self.attention_attrs["v_path"] = f"{v_add_name}/output_0" + + # 3. Apply 3D RoPE (MRoPE) + cos_dynamic, sin_dynamic = self.make_dynamic_rope_caches(layer_id, basename=f"/model/layers.{layer_id}/attn/mrope_dynamic_cache") + + # Apply rotation to Q + self.attention_attrs["q_path"] = self.apply_mrope_rotation( + layer_id, + self.attention_attrs["q_path"], + q_shape, + cos_dynamic, + sin_dynamic, + self.num_attn_heads, + basename=f"/model/layers.{layer_id}/attn/q_mrope" + ) + + # Apply rotation to K + self.attention_attrs["k_path"] = self.apply_mrope_rotation( + layer_id, + self.attention_attrs["k_path"], + k_shape, + cos_dynamic, + sin_dynamic, + self.num_kv_heads, + basename=f"/model/layers.{layer_id}/attn/k_mrope" + ) + + # 4. Call GroupQueryAttention op + past_k = f"past_key_values.{layer_id}.key" + past_v = f"past_key_values.{layer_id}.value" + present_k = f"present.{layer_id}.key" + present_v = f"present.{layer_id}.value" + + attn_name = f"/model/layers.{layer_id}/attn/{self.attention_attrs['op_type']}" + self.make_attention_op( + attn_name, + q_path=self.attention_attrs["q_path"], + k_path=self.attention_attrs["k_path"], + v_path=self.attention_attrs["v_path"], + past_k=past_k, + past_v=past_v, + present_k=present_k, + present_v=present_v, + # Pass empty strings for fused caches since we applied RoPE manually + cos_cache="", + sin_cache="", + **kwargs, + ) + + # 5. Build O-proj + o_proj = 'o_proj' if hasattr(attention, 'o_proj') else 'dense' + o_matmul_basename = f"/model/layers.{layer_id}/attn/o_proj/MatMul" + o_weight = getattr(attention, o_proj) + o_matmul_name = self.make_matmul(o_weight, o_matmul_basename, f"{attn_name}/output_0") + + o_bias_exists = getattr(attention, o_proj).bias is not None + if o_bias_exists: + o_add_name = f"/model/layers.{layer_id}/attn/o_proj/Add" + o_bias = getattr(attention, o_proj).bias + self.make_add_bias(o_bias, o_add_name, root_input=f"{o_matmul_name}/output_0") + self.layernorm_attrs["skip_input"] = f"{o_add_name}/output_0" + else: + self.layernorm_attrs["skip_input"] = f"{o_matmul_name}/output_0" + + def make_model(self, input_path, config=None): + + # Make inputs and outputs to ONNX model + self.make_inputs_and_outputs() + + # Make pre-processing nodes + self.make_preprocessing_nodes() + + # Load the Hugging Face model + from transformers import Qwen2_5_VLForConditionalGeneration + print("Loading Qwen2_5_VLForConditionalGeneration model...") + hf_model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + self.model_name_or_path, + config=config, + cache_dir=self.cache_dir, + token=self.hf_token, + trust_remote_code=self.hf_remote + ) + + # We only want to export the text model + model = hf_model.language_model + print(f"Isolated language_model ({model.__class__.__name__}) for ONNX export.") + + # Loop through model and map each module to ONNX/ORT ops + self.layer_id = 0 + + # The base.Model.make_model() loop expects modules from a standard causal LM, + # so we replicate its logic here but point to the correct modules in the hf_model + + # Handle Embeddings + if not self.exclude_embeds: + print("Reading embedding layer") + # The text model's embeddings are at model.embed_tokens + self.make_embedding(model.embed_tokens.weight) + else: + # When excluding embeds, the input is `inputs_embeds` + print("Skipping embedding layer, model will expect 'inputs_embeds'.") + self.layernorm_attrs["root_input"] = "inputs_embeds" + self.layernorm_attrs["skip_input"] = "inputs_embeds" + + # Handle Decoder Layers + for layer in model.layers: + if self.layer_id < self.num_layers: + print(f"Reading decoder layer {self.layer_id}") + self.make_layer(self.layer_id, layer) + self.layer_id += 1 + + # Handle Final Norm + if self.layer_id == self.num_layers and hasattr(model, "norm"): + print("Reading final norm") + self.make_layernorm(self.layer_id, model.norm, skip=True, simple=self.layernorm_attrs["simple"], location="final_norm") + + # Handle LM Head + if not self.exclude_lm_head: + # The LM head is part of the parent Qwen2_5_VLForConditionalGeneration model + print("Reading LM head") + self.make_lm_head(hf_model.lm_head) + + del model + del hf_model diff --git a/test/python/models/qwen_2.5_vl/run.sh b/test/python/models/qwen_2.5_vl/run.sh new file mode 100644 index 0000000000..fb7e339581 --- /dev/null +++ b/test/python/models/qwen_2.5_vl/run.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +# --- +# This script builds and tests either an fp32, bf16 or fp16 model. Append -f to force export. +# +# Usage: ./run.sh [fp32|bf16|fp16] [-f] +# --- + +# Exit immediately if a command fails +set -e + +# 1. Validate Input +if [ "$1" != "fp32" ] && [ "$1" != "bf16" ] && [ "$1" != "fp16" ]; then + echo "Error: Invalid precision." + echo "Usage: $0 fp32|bf16|fp16" + exit 1 +fi + +# 2. Define variables based on input +PRECISION=$1 +OUTPUT_DIR="./qwen_${PRECISION}" +ONNX_MODEL_PATH="${OUTPUT_DIR}/model.onnx" +CACHE_DIR="./cache" +HF_MODEL="Qwen/Qwen2.5-VL-3B-Instruct" + +# Set the --bf16 or --fp16 flag for the test script +TEST_FLAG="" +if [ "$PRECISION" == "bf16" ]; then + TEST_FLAG="--bf16" +elif [ "$PRECISION" == "fp16" ]; then + TEST_FLAG="--fp16" +fi + +# 3. Remove output directory only if it exists and -f flag is provided. +if [ "$2" == "-f" ] && [ -d "${OUTPUT_DIR}" ]; then + echo "Removing existing directory: ${OUTPUT_DIR}" + rm -rf "${OUTPUT_DIR}" +fi + +# 4. Run the builder script if output directory does not exist. +if ! [ -d "${OUTPUT_DIR}" ]; then + echo "--- Building ${PRECISION} model ---" + python -m onnxruntime_genai.models.builder \ + -m ${HF_MODEL} \ + -p ${PRECISION} \ + -o ${OUTPUT_DIR} \ + -e cuda \ + -c ${CACHE_DIR} +fi + +# 5. Run the parity test +echo "--- Testing ${PRECISION} model parity ---" +python test_qwen_2.5_vl.py \ + --hf_model ${HF_MODEL} \ + --cache_dir ${CACHE_DIR} \ + --onnx_model ${ONNX_MODEL_PATH} \ + ${TEST_FLAG} + +echo "--- ${PRECISION} run complete ---" \ No newline at end of file diff --git a/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py b/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py new file mode 100644 index 0000000000..b2d32afb01 --- /dev/null +++ b/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py @@ -0,0 +1,451 @@ +import os +import argparse +import torch +import numpy as np +import onnxruntime as ort +from onnx import TensorProto # Import TensorProto +# The modeling script is in the transformers library, so we import it +from transformers import Qwen2_5_VLForConditionalGeneration +from typing import Tuple, Dict, Any, List + +# --- Helper Functions --- + +def torch_dtype_to_onnx_tensor_proto(dtype: torch.dtype) -> int: + """Maps torch.dtype to onnx.TensorProto.DataType""" + if dtype == torch.float32: + return TensorProto.FLOAT + if dtype == torch.float16: + return TensorProto.FLOAT16 + if dtype == torch.bfloat16: + return TensorProto.BFLOAT16 + if dtype == torch.int64: + return TensorProto.INT64 + if dtype == torch.int32: + return TensorProto.INT32 + if dtype == torch.bool: + return TensorProto.BOOL + raise ValueError(f"Unsupported torch dtype: {dtype}") + +def to_numpy(tensor): + """Move tensor to CPU and convert to numpy, handling bf16.""" + if tensor.dtype == torch.bfloat16: + # NumPy doesn't support bfloat16, so cast to float32 first + return tensor.detach().cpu().to(torch.float32).numpy() + return tensor.detach().cpu().numpy() + +def compare_outputs( + hf_logits: torch.Tensor, + ort_logits: torch.Tensor, # Changed to torch.Tensor + hf_presents: List[Tuple[torch.Tensor, torch.Tensor]], + ort_presents: List[torch.Tensor], # Changed to list[torch.Tensor] + step_name: str, + rtol: float, + atol: float +): + """Compares logits and KV cache outputs using numpy.""" + + print(f"--- Comparing {step_name} Logits ---") + + # We can use to_numpy safely here because we'll compare fp32 vs fp32 + # or (bf16->fp32) vs (bf16->fp32) + np.testing.assert_allclose( + to_numpy(hf_logits), + to_numpy(ort_logits), + rtol=rtol, + atol=atol + ) + print("Logits: PASS") + + print(f"\n--- Comparing {step_name} KV Cache ---") + # hf_presents is now a list of tuples: [(k0, v0), (k1, v1), ...] + # Flatten it to a list: [k0, v0, k1, v1, ...] + hf_presents_list = [t for layer_kv in hf_presents for t in layer_kv] + + assert len(hf_presents_list) == len(ort_presents), \ + f"HF presents count ({len(hf_presents_list)}) != ORT presents count ({len(ort_presents)})" + + for i in range(len(hf_presents_list)): + layer = i // 2 + kv_type = "key" if i % 2 == 0 else "value" + + hf_tensor = hf_presents_list[i] + ort_tensor = ort_presents[i] + + np.testing.assert_allclose( + to_numpy(hf_tensor), + to_numpy(ort_tensor), + rtol=rtol, + atol=atol + ) + print(f"KV Cache (all {len(hf_presents_list)} tensors): PASS") + print(f"\nāœ… {step_name} Parity Test Passed!\n") + +def ort_io_binding_helper( + sess: ort.InferenceSession, + input_tensors: Dict[str, torch.Tensor], + output_tensors: Dict[str, torch.Tensor], + device: str +) -> None: + """ + Binds torch tensors to an ONNX Runtime IOBinding object and runs the session. + Tensors must be on the correct device (e.g., 'cuda:0'). + """ + bind = sess.io_binding() + + # Get device type and index for ORT + ort_device = device.split(":")[0] + ort_device_id = 0 + if ":" in device: + ort_device_id = int(device.split(":")[1]) + + for name, tensor in input_tensors.items(): + if not tensor.is_contiguous(): + print(f"Warning: Input tensor {name} is not contiguous. Making it contiguous.") + tensor = tensor.contiguous() + input_tensors[name] = tensor # Update dict entry for future runs (decode) + + bind.bind_input( + name, + ort_device, + ort_device_id, + torch_dtype_to_onnx_tensor_proto(tensor.dtype), + tensor.shape, + tensor.data_ptr() + ) + + for name, tensor in output_tensors.items(): + if not tensor.is_contiguous(): + print(f"Warning: Output tensor {name} is not contiguous. Making it contiguous.") + tensor = tensor.contiguous() + output_tensors[name] = tensor # Update dict entry + + bind.bind_output( + name, + ort_device, + ort_device_id, + torch_dtype_to_onnx_tensor_proto(tensor.dtype), + tensor.shape, + tensor.data_ptr() + ) + + sess.run_with_iobinding(bind) + + +def test_parity(hf_model_name: str, cache_dir: str, onnx_model_path: str, use_gpu: bool, use_bf16: bool, use_fp16: bool): + """ + Runs a two-step (prefill and decode) parity test between the Hugging Face + and ONNX models. + """ + + print(f"Loading Hugging Face model: {hf_model_name}") + print("This requires `trust_remote_code=True`.") + + if not use_gpu: + print("ERROR: This test script now requires a GPU (`--cpu` is not supported) due to IOBinding.") + return + + device = "cuda:0" # IOBinding needs the specific device ID + + if use_bf16: + torch_dtype = torch.bfloat16 + # Standard BF16 tolerances + rtol, atol = 2e-1, 1 + elif use_fp16: + torch_dtype = torch.float16 + # Standard FP16 tolerances + rtol, atol = 1e-1, 5e-1 + else: + torch_dtype = torch.float32 + # Standard FP32 tolerances + rtol, atol = 1e-1, 1e-1 + + allow_bf16_logits = os.getenv("allow_bf16_logits") in ["1", "true", "True"] + + if allow_bf16_logits: + logits_dtype = torch_dtype + else: + # The builder script (base.Model) upcasts logits to float32 + # ONLY when the io_dtype is bfloat16. + # For FP16 or FP32, it keeps the original dtype. + logits_dtype = torch.float32 if use_bf16 else torch_dtype + + print(f"Allocating ONNX logits output buffer with dtype: {logits_dtype}") + + hf_full_model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + hf_model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + cache_dir=cache_dir + ).to(device).eval() + + # The ONNX model is *only* the language_model component + hf_text_model = hf_full_model.language_model + config = hf_text_model.config + + # Get model parameters + BATCH_SIZE = 1 + PREFILL_LEN = 10 + DECODE_LEN = 1 + HIDDEN_SIZE = config.hidden_size + NUM_LAYERS = config.num_hidden_layers + NUM_KV_HEADS = config.num_key_value_heads + HEAD_DIM = config.hidden_size // config.num_attention_heads + VOCAB_SIZE = config.vocab_size # Get vocab size for output + + print("\n--- Model Parameters ---") + print(f"Device: {device}") + print(f"DType: {torch_dtype}") + print(f"RTOL: {rtol}, ATOL: {atol}") + print(f"Layers: {NUM_LAYERS}") + print(f"Hidden Size: {HIDDEN_SIZE}") + print(f"KV Heads: {NUM_KV_HEADS}") + print(f"Head Dim: {HEAD_DIM}") + print("------------------------\n") + + print(f"Loading ONNX model: {onnx_model_path}") + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + sess = ort.InferenceSession(onnx_model_path, providers=providers) + + # Get all ONNX output names + output_names = [o.name for o in sess.get_outputs()] + + # ================================================================= + # 1. PREFILL STEP + # ================================================================= + print(f"Running Prefill Step (Sequence Length = {PREFILL_LEN})...") + + # --- Create HF/Torch Inputs --- + # Use randn (normal distribution) scaled down for better stability in FP16 + # inputs_embeds are normally centered around 0, unlike rand which is [0, 1] + inputs_embeds_prefill = torch.randn( + (BATCH_SIZE, PREFILL_LEN, HIDDEN_SIZE), + dtype=torch_dtype, + device=device + ) * 0.001 + + # Qwen2.5-VL uses 3D position IDs (temporal, height, width). + # For text tokens, all three dimensions typically use the same sequence index. + pos_ids_1d_prefill = torch.arange(PREFILL_LEN, device=device).expand(BATCH_SIZE, -1) + position_ids_prefill = pos_ids_1d_prefill.unsqueeze(0).expand(3, -1, -1) + + attention_mask_prefill = torch.ones( + (BATCH_SIZE, PREFILL_LEN), + dtype=torch.int64, + device=device + ) + + cache_position_prefill = torch.arange(PREFILL_LEN, device=device) + + # --- Create ONNX Input Tensors (on device) --- + ort_inputs_prefill = { + "inputs_embeds": inputs_embeds_prefill, + "position_ids": position_ids_prefill, + "attention_mask": attention_mask_prefill + } + + # Create dummy pasts with 0 sequence length + past_shape = (BATCH_SIZE, NUM_KV_HEADS, 0, HEAD_DIM) + dummy_past = torch.empty(past_shape, dtype=torch_dtype, device=device) + for i in range(NUM_LAYERS): + ort_inputs_prefill[f"past_key_values.{i}.key"] = dummy_past + ort_inputs_prefill[f"past_key_values.{i}.value"] = dummy_past + + # --- Create ONNX Output Tensors (on device) --- + ort_logits_prefill = torch.empty( + (BATCH_SIZE, PREFILL_LEN, VOCAB_SIZE), + dtype=logits_dtype, + device=device + ) + ort_presents_prefill = [] + ort_outputs_prefill = {"logits": ort_logits_prefill} + present_shape = (BATCH_SIZE, NUM_KV_HEADS, PREFILL_LEN, HEAD_DIM) + + for i in range(NUM_LAYERS): + ort_present_k = torch.empty(present_shape, dtype=torch_dtype, device=device) + ort_present_v = torch.empty(present_shape, dtype=torch_dtype, device=device) + ort_outputs_prefill[f"present.{i}.key"] = ort_present_k + ort_outputs_prefill[f"present.{i}.value"] = ort_present_v + ort_presents_prefill.extend([ort_present_k, ort_present_v]) + + # --- Run HF Model --- + with torch.no_grad(): + hf_outputs_prefill = hf_text_model( + inputs_embeds=inputs_embeds_prefill, + position_ids=position_ids_prefill, + attention_mask=attention_mask_prefill, + past_key_values=None, + cache_position=cache_position_prefill, + return_dict=True, + use_cache=True + ) + + # --- Run ONNX Model with IOBinding --- + ort_io_binding_helper(sess, ort_inputs_prefill, ort_outputs_prefill, device) + + # --- Compare Prefill --- + hf_logits_prefill = hf_full_model.lm_head(hf_outputs_prefill.last_hidden_state) + hf_presents_prefill = hf_outputs_prefill.past_key_values + + compare_outputs( + hf_logits_prefill, + ort_logits_prefill, # This is the tensor we pre-allocated + hf_presents_prefill, + ort_presents_prefill, # This is the list of tensors we pre-allocated + step_name="Prefill", + rtol=rtol, + atol=atol + ) + + # ================================================================= + # 2. DECODE STEP + # ================================================================= + print(f"Running Decode Step (Sequence Length = {DECODE_LEN})...") + + # --- Create HF/Torch Inputs --- + # Use randn (normal distribution) scaled down + inputs_embeds_decode = torch.randn( + (BATCH_SIZE, DECODE_LEN, HIDDEN_SIZE), + dtype=torch_dtype, + device=device + ) * 0.001 + + # Position IDs continue from prefill length + pos_ids_1d_decode = torch.tensor( + [[PREFILL_LEN]], + dtype=torch.int64, + device=device + ) + position_ids_decode = pos_ids_1d_decode.unsqueeze(0).expand(3, -1, -1) + + attention_mask_decode = torch.ones( + (BATCH_SIZE, PREFILL_LEN + DECODE_LEN), + dtype=torch.int64, + device=device + ) + + cache_position_decode = torch.tensor([PREFILL_LEN], device=device) + + # Use the KV cache from the HF prefill run + hf_past_key_values = hf_outputs_prefill.past_key_values + + # --- Create ONNX Input Tensors (on device) --- + ort_inputs_decode = { + "inputs_embeds": inputs_embeds_decode, + "position_ids": position_ids_decode, + "attention_mask": attention_mask_decode + } + + # Use the KV cache from the ONNX prefill run (these are already torch tensors) + for i in range(NUM_LAYERS): + ort_inputs_decode[f"past_key_values.{i}.key"] = ort_presents_prefill[i*2] + ort_inputs_decode[f"past_key_values.{i}.value"] = ort_presents_prefill[i*2 + 1] + + # --- Create ONNX Output Tensors (on device) --- + # --- FIX: Logits from bf16 ONNX model are intentionally float32 for accuracy --- + ort_logits_decode = torch.empty( + (BATCH_SIZE, DECODE_LEN, VOCAB_SIZE), + dtype=logits_dtype, + device=device + ) + ort_presents_decode = [] + ort_outputs_decode = {"logits": ort_logits_decode} + present_shape_decode = (BATCH_SIZE, NUM_KV_HEADS, PREFILL_LEN + DECODE_LEN, HEAD_DIM) + + for i in range(NUM_LAYERS): + ort_present_k = torch.empty(present_shape_decode, dtype=torch_dtype, device=device) + ort_present_v = torch.empty(present_shape_decode, dtype=torch_dtype, device=device) + ort_outputs_decode[f"present.{i}.key"] = ort_present_k + ort_outputs_decode[f"present.{i}.value"] = ort_present_v + ort_presents_decode.extend([ort_present_k, ort_present_v]) + + # --- Run HF Model --- + with torch.no_grad(): + hf_outputs_decode = hf_text_model( + inputs_embeds=inputs_embeds_decode, + position_ids=position_ids_decode, + attention_mask=attention_mask_decode, + past_key_values=hf_past_key_values, + cache_position=cache_position_decode, + return_dict=True, + use_cache=True + ) + + # --- Run ONNX Model with IOBinding --- + ort_io_binding_helper(sess, ort_inputs_decode, ort_outputs_decode, device) + + # --- Compare Decode --- + hf_logits_decode = hf_full_model.lm_head(hf_outputs_decode.last_hidden_state) + hf_presents_decode = hf_outputs_decode.past_key_values + + compare_outputs( + hf_logits_decode, + ort_logits_decode, + hf_presents_decode, + ort_presents_decode, + step_name="Decode", + rtol=rtol, + atol=atol + ) + + print("="*30) + print("šŸŽ‰ All Parity Tests Passed! šŸŽ‰") + print("="*30) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Parity test for Qwen 2.5 VL ONNX model.") + parser.add_argument( + "--hf_model", + type=str, + default="Qwen/Qwen2.5-VL-7B-Instruct", + help="Path or name of the Hugging Face model." + ) + parser.add_argument( + "--onnx_model", + type=str, + required=True, + help="Path to the exported ONNX model file." + ) + parser.add_argument( + "--cache_dir", + type=str, + default="./qwen2.5_vl_7b_instruct", + help="Path to the cache directory." + ) + + parser.add_argument( + "--cpu", + action="store_true", + help="Force running the test on CPU (Not supported with IOBinding)." + ) + + parser.add_argument( + "--bf16", + action="store_true", + help="Use bf16 precision." + ) + + parser.add_argument( + "--fp16", + action="store_true", + help="Use fp16 precision." + ) + + args = parser.parse_args() + + if args.cpu and (args.bf16 or args.fp16): + print("Warning: Cannot run bf16/fp16 on CPU. Forcing float32.") + args.bf16 = False + args.fp16 = False + + if args.cpu: + print("Warning: CPU testing with IOBinding is not set up. Forcing GPU.") + # This script is now GPU-only + + test_parity( + hf_model_name=args.hf_model, + cache_dir=args.cache_dir, + onnx_model_path=args.onnx_model, + use_gpu=True, # Forcing GPU + use_bf16=args.bf16, + use_fp16=args.fp16 + ) \ No newline at end of file From ec2626c004290f60d51e019faf7eba12093e23fa Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 19 Nov 2025 01:05:59 +0000 Subject: [PATCH 02/18] add header --- test/python/models/qwen_2.5_vl/run.sh | 12 +++++++----- test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py | 5 +++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/test/python/models/qwen_2.5_vl/run.sh b/test/python/models/qwen_2.5_vl/run.sh index fb7e339581..9c3c60bbcb 100644 --- a/test/python/models/qwen_2.5_vl/run.sh +++ b/test/python/models/qwen_2.5_vl/run.sh @@ -1,10 +1,12 @@ -#!/bin/bash +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- -# --- -# This script builds and tests either an fp32, bf16 or fp16 model. Append -f to force export. -# +#!/bin/bash +# This script builds and tests either an fp32, bf16 or fp16 Qwen2.5-VL-3B-Instruct model. Append -f to force export. # Usage: ./run.sh [fp32|bf16|fp16] [-f] -# --- # Exit immediately if a command fails set -e diff --git a/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py b/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py index b2d32afb01..2cffe3ba10 100644 --- a/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py +++ b/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- import os import argparse import torch From 2e4f7289c2c61c176418ec8737d2e9159185c00d Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 19 Nov 2025 01:25:56 +0000 Subject: [PATCH 03/18] fix lint warnings --- src/python/py/models/builders/qwen.py | 7 ++----- .../python/models/qwen_2.5_vl/test_qwen_2.5_vl.py | 15 +++------------ 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index baba2bcffa..92891027a8 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -158,7 +158,7 @@ def make_inputs_and_outputs(self): self.input_shapes["position_ids"] = [3, "batch_size", "sequence_length"] # Call the base Model's make_inputs_and_outputs (skipping MistralModel's) - super(QwenModel, self).make_inputs_and_outputs() + super().make_inputs_and_outputs() def make_dynamic_rope_caches(self, layer_id, basename): """ @@ -459,7 +459,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn def make_attention(self, layer_id, attention, root_input, **kwargs): # 1. Unpack QKV if necessary (e.g. qkv_proj) - super(QwenModel, self).make_attention_unpacked(layer_id, attention, root_input, **kwargs) + super().make_attention_unpacked(layer_id, attention, root_input, **kwargs) # 2. Build Q/K/V MatMul and Add nodes q_matmul_basename = f"/model/layers.{layer_id}/attn/q_proj/MatMul" @@ -613,6 +613,3 @@ def make_model(self, input_path, config=None): # The LM head is part of the parent Qwen2_5_VLForConditionalGeneration model print("Reading LM head") self.make_lm_head(hf_model.lm_head) - - del model - del hf_model diff --git a/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py b/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py index 2cffe3ba10..c0ab3c6d01 100644 --- a/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py +++ b/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py @@ -8,12 +8,9 @@ import torch import numpy as np import onnxruntime as ort -from onnx import TensorProto # Import TensorProto -# The modeling script is in the transformers library, so we import it +from onnx import TensorProto from transformers import Qwen2_5_VLForConditionalGeneration -from typing import Tuple, Dict, Any, List - -# --- Helper Functions --- +from typing import Tuple, Dict, List def torch_dtype_to_onnx_tensor_proto(dtype: torch.dtype) -> int: """Maps torch.dtype to onnx.TensorProto.DataType""" @@ -69,10 +66,7 @@ def compare_outputs( assert len(hf_presents_list) == len(ort_presents), \ f"HF presents count ({len(hf_presents_list)}) != ORT presents count ({len(ort_presents)})" - for i in range(len(hf_presents_list)): - layer = i // 2 - kv_type = "key" if i % 2 == 0 else "value" - + for i in range(len(hf_presents_list)): hf_tensor = hf_presents_list[i] ort_tensor = ort_presents[i] @@ -211,9 +205,6 @@ def test_parity(hf_model_name: str, cache_dir: str, onnx_model_path: str, use_gp providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] sess = ort.InferenceSession(onnx_model_path, providers=providers) - # Get all ONNX output names - output_names = [o.name for o in sess.get_outputs()] - # ================================================================= # 1. PREFILL STEP # ================================================================= From afa98449d7e2918f795e8df51203923d5b93490d Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 19 Nov 2025 10:54:08 -0800 Subject: [PATCH 04/18] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/python/py/models/builders/qwen.py | 18 +++++++++--------- test/python/models/qwen_2.5_vl/run.sh | 3 +-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index 92891027a8..9219f2e34a 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -5,11 +5,11 @@ # -------------------------------------------------------------------------- import os -from .base import Model # Changed this to match your new inheritance +from .base import Model import onnx_ir as ir import torch -class QwenModel(Model): # Changed this to match your new inheritance +class QwenModel(Model): def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) @@ -39,7 +39,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): config.rope_scaling = text_config_dict["rope_scaling"] # Need this for attention_scaling calculation if "original_max_position_embeddings" in text_config_dict: - config.original_max_position_embeddings = text_config_dict["original_max_position_embeddings"] + config.original_max_position_embeddings = text_config_dict["original_max_position_embeddings"] super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) @@ -98,7 +98,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): orig_max_pos = config.original_max_position_embeddings self.rope_attrs["attention_scaling"] = config.rope_scaling.get("attention_factor", torch.sqrt(1 + torch.log(torch.tensor(factor)) / torch.log(torch.tensor(orig_max_pos))).item()) else: - self.rope_attrs["attention_scaling"] = 1.0 + self.rope_attrs["attention_scaling"] = 1.0 # Qwen 2.5 VL applies RoPE manually before attention, not fused in the op self.attention_attrs["use_rope_in_attn"] = False @@ -121,8 +121,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): self.mrope_splits = self.mrope_sections * 2 if sum(self.mrope_splits) != self.head_size: - # The sum (128) should now correctly match self.head_size (128) - raise ValueError(f"MRoPE splits {self.mrope_splits} sum ({sum(self.mrope_splits)}) does not match head size ({self.head_size})") + # The sum (128) should now correctly match self.head_size (128) + raise ValueError(f"MRoPE splits {self.mrope_splits} sum ({sum(self.mrope_splits)}) does not match head size ({self.head_size})") # Force GroupQueryAttention for fp32 cuda, # as base.py's make_attention_init doesn't include this combo. @@ -131,7 +131,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): print("Forcing GroupQueryAttention (GQA) for FP32 CUDA.") if self.attention_attrs["op_type"] != "GroupQueryAttention": - raise ValueError(f"Qwen2.5-VL requires GroupQueryAttention, but op_type is {self.attention_attrs['op_type']}. This may be due to an unsupported EP/precision combo.") + raise ValueError(f"Qwen2.5-VL requires GroupQueryAttention, but op_type is {self.attention_attrs['op_type']}. This may be due to an unsupported EP/precision combo.") # Create and save the inv_freq tensor self.make_inv_freq_tensor() @@ -146,8 +146,8 @@ def make_inv_freq_tensor(self): # The HF model expects H/2, not R/2 if dim != self.head_size: - print(f"Warning: partial_rotary_factor ({self.rope_attrs['partial_rotary_factor']}) is not 1. This might be unsupported.") - inv_freq = inv_freq[:(self.head_size // 2)] + print(f"Warning: partial_rotary_factor ({self.rope_attrs['partial_rotary_factor']}) is not 1. This might be unsupported.") + inv_freq = inv_freq[:(self.head_size // 2)] self.make_initializer(inv_freq, "model.inv_freq", to=ir.DataType.FLOAT) print("Created and saved 'model.inv_freq' initializer.") diff --git a/test/python/models/qwen_2.5_vl/run.sh b/test/python/models/qwen_2.5_vl/run.sh index 9c3c60bbcb..9eee36e151 100644 --- a/test/python/models/qwen_2.5_vl/run.sh +++ b/test/python/models/qwen_2.5_vl/run.sh @@ -1,10 +1,9 @@ +#!/bin/bash # ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- - -#!/bin/bash # This script builds and tests either an fp32, bf16 or fp16 Qwen2.5-VL-3B-Instruct model. Append -f to force export. # Usage: ./run.sh [fp32|bf16|fp16] [-f] From 89632a22ae17850993bb0c117557e517983822e6 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 19 Nov 2025 19:13:51 +0000 Subject: [PATCH 05/18] format --- src/python/py/models/builder.py | 26 +- src/python/py/models/builders/__init__.py | 54 +- src/python/py/models/builders/base.py | 2287 +++++++++++++---- src/python/py/models/builders/phi.py | 280 +- src/python/py/models/builders/qwen.py | 515 +++- src/python/py/models/test_vl.py | 61 + .../models/qwen_2.5_vl/test_qwen_2.5_vl.py | 270 +- 7 files changed, 2671 insertions(+), 822 deletions(-) create mode 100644 src/python/py/models/test_vl.py diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 3f9e982b49..b4b203618e 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -14,10 +14,6 @@ import onnx_ir as ir import torch -from transformers import ( - AutoConfig, -) - from builders import ( ChatGLMModel, ErnieModel, @@ -39,11 +35,14 @@ Phi3VModel, Phi4MMModel, PhiModel, - Qwen25VLTextModel, Qwen3Model, + Qwen25VLTextModel, QwenModel, SmolLM3Model, ) +from transformers import ( + AutoConfig, +) def check_extra_options(kv_pairs, execution_provider): @@ -162,7 +161,15 @@ def set_onnx_dtype(precision: str, extra_options: dict[str, Any]) -> ir.DataType @torch.no_grad -def create_model(model_name, input_path, output_dir, precision, execution_provider, cache_dir, **extra_options): +def create_model( + model_name, + input_path, + output_dir, + precision, + execution_provider, + cache_dir, + **extra_options, +): if execution_provider == "NvTensorRtRtx": execution_provider = "trt-rtx" extra_options["use_qdq"] = True @@ -182,7 +189,10 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid from peft import PeftConfig peft_config = PeftConfig.from_pretrained( - extra_options["adapter_path"], token=hf_token, trust_remote_code=hf_remote, **extra_kwargs + extra_options["adapter_path"], + token=hf_token, + trust_remote_code=hf_remote, + **extra_kwargs, ) config.update(peft_config.__dict__) @@ -307,7 +317,7 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid if not config_only: # Make ONNX model - onnx_model.make_model(input_path, config) + onnx_model.make_model(input_path) # Save ONNX model onnx_model.save_model(output_dir) diff --git a/src/python/py/models/builders/__init__.py b/src/python/py/models/builders/__init__.py index 6f606fd5ae..cc41f1182e 100644 --- a/src/python/py/models/builders/__init__.py +++ b/src/python/py/models/builders/__init__.py @@ -4,27 +4,51 @@ # license information. # -------------------------------------------------------------------------- from .base import Model +from .chatglm import ChatGLMModel +from .ernie import ErnieModel +from .gemma import Gemma2Model, Gemma3Model, GemmaModel +from .gptoss import GPTOSSModel +from .granite import GraniteModel from .llama import LlamaModel from .mistral import MistralModel -from .qwen import QwenModel, Qwen3Model, Qwen25VLTextModel -from .phi import ( - PhiModel, Phi3MiniModel, Phi3MiniLongRoPEModel, Phi3SmallModel, - Phi3SmallLongRoPEModel, Phi3VModel, Phi3MoELongRoPEModel, Phi4MMModel -) -from .gemma import GemmaModel, Gemma2Model, Gemma3Model from .nemotron import NemotronModel -from .chatglm import ChatGLMModel from .olmo import OLMoModel -from .granite import GraniteModel -from .ernie import ErnieModel +from .phi import ( + Phi3MiniLongRoPEModel, + Phi3MiniModel, + Phi3MoELongRoPEModel, + Phi3SmallLongRoPEModel, + Phi3SmallModel, + Phi3VModel, + Phi4MMModel, + PhiModel, +) +from .qwen import Qwen3Model, Qwen25VLTextModel, QwenModel from .smollm import SmolLM3Model -from .gptoss import GPTOSSModel __all__ = [ + "ChatGLMModel", + "ErnieModel", + "GPTOSSModel", + "Gemma2Model", + "Gemma3Model", + "GemmaModel", + "GraniteModel", + "LlamaModel", + "MistralModel", "Model", - "LlamaModel", "MistralModel", "QwenModel", "Qwen3Model", "Qwen25VLTextModel", "PhiModel", - "Phi3MiniModel", "Phi3MiniLongRoPEModel", "Phi3SmallModel", - "Phi3SmallLongRoPEModel", "Phi3VModel", "Phi3MoELongRoPEModel", "Phi4MMModel", - "GemmaModel", "Gemma2Model", "Gemma3Model", "NemotronModel", "ChatGLMModel", - "OLMoModel", "GraniteModel", "ErnieModel", "SmolLM3Model", "GPTOSSModel" + "NemotronModel", + "OLMoModel", + "Phi3MiniLongRoPEModel", + "Phi3MiniModel", + "Phi3MoELongRoPEModel", + "Phi3SmallLongRoPEModel", + "Phi3SmallModel", + "Phi3VModel", + "Phi4MMModel", + "PhiModel", + "Qwen3Model", + "Qwen25VLTextModel", + "QwenModel", + "SmolLM3Model", ] diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index b856b2eb2f..95cd4d97e2 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -10,12 +10,12 @@ import ast import json import os -from typing import Sequence +from collections.abc import Sequence import numpy as np import onnx_ir as ir import torch -from onnx_ir.tensor_adapters import to_torch_dtype, TorchTensor +from onnx_ir.tensor_adapters import TorchTensor, to_torch_dtype from onnxruntime.quantization.matmul_nbits_quantizer import ( MatMulNBitsQuantizer, QuantFormat, @@ -29,6 +29,7 @@ GenerationConfig, ) + def parse_hf_token(hf_token): """ Returns the authentication token needed for Hugging Face. @@ -45,19 +46,51 @@ def parse_hf_token(hf_token): # Return user-provided token as string return hf_token + class Model: def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): + self.config = config self.context_length = config.seq_length if hasattr(config, "seq_length") else config.max_position_embeddings - self.original_context_length = config.original_max_position_embeddings if hasattr(config, "original_max_position_embeddings") else config.rope_scaling["original_max_position_embeddings"] if hasattr(config, "rope_scaling") and hasattr(config.rope_scaling, "original_max_position_embeddings") else self.context_length - self.window_size = config.sliding_window if hasattr(config, "sliding_window") else -1 # default is -1 in GroupQueryAttention kernel - self.intermediate_size = config.ffn_hidden_size if hasattr(config, "ffn_hidden_size") else config.intermediate_size + self.original_context_length = ( + config.original_max_position_embeddings + if hasattr(config, "original_max_position_embeddings") + else config.rope_scaling["original_max_position_embeddings"] + if hasattr(config, "rope_scaling") and hasattr(config.rope_scaling, "original_max_position_embeddings") + else self.context_length + ) + self.window_size = ( + config.sliding_window if hasattr(config, "sliding_window") else -1 + ) # default is -1 in GroupQueryAttention kernel + self.intermediate_size = ( + config.ffn_hidden_size if hasattr(config, "ffn_hidden_size") else config.intermediate_size + ) self.hidden_size = config.hidden_size - self.num_kv_heads = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.multi_query_group_num if hasattr(config, "multi_query_group_num") else config.num_attention_heads + self.num_kv_heads = ( + config.num_key_value_heads + if hasattr(config, "num_key_value_heads") + else config.multi_query_group_num + if hasattr(config, "multi_query_group_num") + else config.num_attention_heads + ) self.num_attn_heads = config.num_attention_heads - self.head_size = config.head_dim if hasattr(config, "head_dim") and config.head_dim is not None else config.hidden_size // config.num_attention_heads - self.num_layers = int(extra_options["num_hidden_layers"]) if "num_hidden_layers" in extra_options else config.num_hidden_layers if hasattr(config, "num_hidden_layers") else config.num_layers + self.head_size = ( + config.head_dim + if hasattr(config, "head_dim") and config.head_dim is not None + else config.hidden_size // config.num_attention_heads + ) + self.num_layers = ( + int(extra_options["num_hidden_layers"]) + if "num_hidden_layers" in extra_options + else config.num_hidden_layers + if hasattr(config, "num_hidden_layers") + else config.num_layers + ) self.vocab_size = config.vocab_size - self.activation = config.hidden_activation if hasattr(config, "hidden_activation") and config.hidden_activation is not None else config.hidden_act + self.activation = ( + config.hidden_activation + if hasattr(config, "hidden_activation") and config.hidden_activation is not None + else config.hidden_act + ) self.model_name_or_path = config._name_or_path self.model_type = config.architectures[0] @@ -88,32 +121,51 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): self.ep_attrs = { "cpu": {}, "cuda": { - "enable_cuda_graph": "1" if extra_options.get("enable_cuda_graph", False) else "0", # "1" if the model is able to enable cuda graph, "0" otherwise - "enable_skip_layer_norm_strict_mode": "1" + "enable_cuda_graph": "1" + if extra_options.get("enable_cuda_graph", False) + else "0", # "1" if the model is able to enable cuda graph, "0" otherwise + "enable_skip_layer_norm_strict_mode": "1", }, "dml": {}, # TODO: Enable graph capture for webgpu once supported both in onnxruntime-genai and onnxruntime. "webgpu": {}, - "trt-rtx": {"enable_cuda_graph": "1"} + "trt-rtx": {"enable_cuda_graph": "1"}, } # Map input names to their types and shapes self.input_names = ["input_ids", "attention_mask", "position_ids"] self.input_types = { - "input_ids": ir.DataType.INT64, # For standard models - "attention_mask": ir.DataType.INT64, # For standard models - "position_ids": ir.DataType.INT64, # For standard models - "inputs_embeds": self.io_dtype, # For standard models where you want to remove the embedding layer from the model (note that `inputs_embeds` is written this way to match Hugging Face format) - "past_key_values.key": self.io_dtype, # For standard models (note that `past_key_values.key` is written this way to match Hugging Face format) - "past_key_values.value": self.io_dtype, # For standard models (note that `past_key_values.value` is written this way to match Hugging Face format) + "input_ids": ir.DataType.INT64, # For standard models + "attention_mask": ir.DataType.INT64, # For standard models + "position_ids": ir.DataType.INT64, # For standard models + "inputs_embeds": self.io_dtype, # For standard models where you want to remove the embedding layer from the model (note that `inputs_embeds` is written this way to match Hugging Face format) + "past_key_values.key": self.io_dtype, # For standard models (note that `past_key_values.key` is written this way to match Hugging Face format) + "past_key_values.value": self.io_dtype, # For standard models (note that `past_key_values.value` is written this way to match Hugging Face format) } self.input_shapes = { - "input_ids": ["batch_size", "sequence_length"], # For standard models - "attention_mask": ["batch_size", "total_sequence_length"], # For standard models - "position_ids": ["batch_size", "sequence_length"], # For standard models - "inputs_embeds": ["batch_size", "sequence_length", self.hidden_size], # For standard models where you want to remove the embedding layer from the model (note that `inputs_embeds` is written this way to match Hugging Face format) - "past_key_values.key": ["batch_size", self.num_kv_heads, "past_sequence_length", self.head_size], # For standard models (note that `past_key_values.key` is written this way to match Hugging Face format) - "past_key_values.value": ["batch_size", self.num_kv_heads, "past_sequence_length", self.head_size], # For standard models (note that `past_key_values.value` is written this way to match Hugging Face format) + "input_ids": ["batch_size", "sequence_length"], # For standard models + "attention_mask": [ + "batch_size", + "total_sequence_length", + ], # For standard models + "position_ids": ["batch_size", "sequence_length"], # For standard models + "inputs_embeds": [ + "batch_size", + "sequence_length", + self.hidden_size, + ], # For standard models where you want to remove the embedding layer from the model (note that `inputs_embeds` is written this way to match Hugging Face format) + "past_key_values.key": [ + "batch_size", + self.num_kv_heads, + "past_sequence_length", + self.head_size, + ], # For standard models (note that `past_key_values.key` is written this way to match Hugging Face format) + "past_key_values.value": [ + "batch_size", + self.num_kv_heads, + "past_sequence_length", + self.head_size, + ], # For standard models (note that `past_key_values.value` is written this way to match Hugging Face format) } self.exclude_embeds = extra_options.get("exclude_embeds", False) if self.exclude_embeds: @@ -122,16 +174,34 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): # Map output names to their types and shapes self.output_names = ["logits"] self.output_types = { - "hidden_states": self.io_dtype, # For standard models where you want to remove the language modeling head from the model (note that `hidden_states` is written this way to match Hugging Face format) - "logits": self.io_dtype, # For standard models - "present.key": self.io_dtype, # For standard models (note that `present.key` is written this way to match Hugging Face format) - "present.value": self.io_dtype, # For standard models (note that `present.value` is written this way to match Hugging Face format) + "hidden_states": self.io_dtype, # For standard models where you want to remove the language modeling head from the model (note that `hidden_states` is written this way to match Hugging Face format) + "logits": self.io_dtype, # For standard models + "present.key": self.io_dtype, # For standard models (note that `present.key` is written this way to match Hugging Face format) + "present.value": self.io_dtype, # For standard models (note that `present.value` is written this way to match Hugging Face format) } self.output_shapes = { - "hidden_states": ["batch_size", "sequence_length", self.hidden_size], # For standard models where you want to remove the language modeling head from the model (note that `hidden_states` is written this way to match Hugging Face format) - "logits": ["batch_size", "sequence_length", self.vocab_size], # For standard models - "present.key": ["batch_size", self.num_kv_heads, "total_sequence_length", self.head_size], # For standard models (note that `present.key` is written this way to match Hugging Face format) - "present.value": ["batch_size", self.num_kv_heads, "total_sequence_length", self.head_size], # For standard models (note that `present.value` is written this way to match Hugging Face format) + "hidden_states": [ + "batch_size", + "sequence_length", + self.hidden_size, + ], # For standard models where you want to remove the language modeling head from the model (note that `hidden_states` is written this way to match Hugging Face format) + "logits": [ + "batch_size", + "sequence_length", + self.vocab_size, + ], # For standard models + "present.key": [ + "batch_size", + self.num_kv_heads, + "total_sequence_length", + self.head_size, + ], # For standard models (note that `present.key` is written this way to match Hugging Face format) + "present.value": [ + "batch_size", + self.num_kv_heads, + "total_sequence_length", + self.head_size, + ], # For standard models (note that `present.value` is written this way to match Hugging Face format) } self.make_outputs_init() @@ -141,104 +211,118 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): # Mask-specific variables # TODO: Reconcile differences between `seqlens_k` and `key_total_seq_lens` in the GroupQueryAttention and SparseAttention implementations. Ideally the same subgraph can be shared for both. self.mask_attrs = { - "mask_name": "", # Name of node that outputs 4D causal attention mask (used as add_qk in MultiHeadAttention) - "seqlens_k": "", # Sum of each row in attention mask - 1 (used as input to GroupQueryAttention) - "total_seq_len": "", # Size of total sequence length in attention mask (used as input to GroupQueryAttention and SparseAttention) - "block_row_indices": "", # Row indices of CSR format of block mask (used as input to SparseAttention) - "block_col_indices": "", # Col indices of CSR format of block mask (used as input to SparseAttention) - "key_total_seq_lens": "", # Sum of each row in attention mask (used as input to SparseAttention) + "mask_name": "", # Name of node that outputs 4D causal attention mask (used as add_qk in MultiHeadAttention) + "seqlens_k": "", # Sum of each row in attention mask - 1 (used as input to GroupQueryAttention) + "total_seq_len": "", # Size of total sequence length in attention mask (used as input to GroupQueryAttention and SparseAttention) + "block_row_indices": "", # Row indices of CSR format of block mask (used as input to SparseAttention) + "block_col_indices": "", # Col indices of CSR format of block mask (used as input to SparseAttention) + "key_total_seq_lens": "", # Sum of each row in attention mask (used as input to SparseAttention) } # Embedding-specific variables self.embed_attrs = { - "scale": 1, # Scale value to multiply output of Embedding layer by + "scale": 1, # Scale value to multiply output of Embedding layer by } # LayerNorm-specific variables epsilon = config.rms_norm_eps if hasattr(config, "rms_norm_eps") else 1e-06 self.layernorm_attrs = { - "simple": True, # Use SimplifiedLayerNorm/SkipSimplifiedLayerNorm vs. LayerNorm/SkipLayerNorm - "first_layernorm": True, # 1st LayerNorm = LayerNorm, then SkipLayerNorm for all subsequent LayerNorms - "last_layernorm": False, # Last LayerNorm = SkipLayerNorm with only output 0 (no output 3) - "root_input": "", # Root input from parent node for LayerNorm and SkipLayerNorm - "skip_input": "", # Skip input from parent node for SkipLayerNorm - "output_0": "", # Output 0 for LayerNorm and SkipLayerNorm - "output_3": "", # Output 3 for SkipLayerNorm - "add_offset": 0, # Offset value for LayerNorm weight - "epsilon": epsilon, # Epsilon value to avoid `sqrt(0)` in LayerNorm - "cast": { # Casting LayerNorm-specific variables - "use_fp32": False, # Use float32 precision to compute LayerNorm - "root_input": False, # Cast root_input - "skip_input": False, # Cast skip_input - "output_0": False, # Cast output_0 - "output_3": False, # Cast output_3 - } + "simple": True, # Use SimplifiedLayerNorm/SkipSimplifiedLayerNorm vs. LayerNorm/SkipLayerNorm + "first_layernorm": True, # 1st LayerNorm = LayerNorm, then SkipLayerNorm for all subsequent LayerNorms + "last_layernorm": False, # Last LayerNorm = SkipLayerNorm with only output 0 (no output 3) + "root_input": "", # Root input from parent node for LayerNorm and SkipLayerNorm + "skip_input": "", # Skip input from parent node for SkipLayerNorm + "output_0": "", # Output 0 for LayerNorm and SkipLayerNorm + "output_3": "", # Output 3 for SkipLayerNorm + "add_offset": 0, # Offset value for LayerNorm weight + "epsilon": epsilon, # Epsilon value to avoid `sqrt(0)` in LayerNorm + "cast": { # Casting LayerNorm-specific variables + "use_fp32": False, # Use float32 precision to compute LayerNorm + "root_input": False, # Cast root_input + "skip_input": False, # Cast skip_input + "output_0": False, # Cast output_0 + "output_3": False, # Cast output_3 + }, } # MatMul-specific variables is_lora = hasattr(config, "peft_type") and config.peft_type == "LORA" self.matmul_attrs = { - "use_lora": is_lora, # Use LoRA/QLoRA format + "use_lora": is_lora, # Use LoRA/QLoRA format } # RotaryEmbedding-specific variables position_scale = config.rope_position_scale if hasattr(config, "rope_position_scale") else 1 partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 rotemb_dim = int(self.head_size * partial_rotary_factor) if partial_rotary_factor != 1.0 else 0 - rope_theta = config.rope_theta if hasattr(config, "rope_theta") else config.rope_embedding_base if hasattr(config, "rope_embedding_base") else 10000 + rope_theta = ( + config.rope_theta + if hasattr(config, "rope_theta") + else config.rope_embedding_base + if hasattr(config, "rope_embedding_base") + else 10000 + ) self.rope_attrs = { - "create_caches": True, # Create cos/sin caches for rotary embeddings - "save_caches": True, # Auto-save cos/sin caches for rotary embeddings after creation - "cache_length": self.context_length, # Cache length to use when creating cos/sin caches for rotary embeddings - "theta": rope_theta, # Base value if calculating cos/sin caches from scratch + "create_caches": True, # Create cos/sin caches for rotary embeddings + "save_caches": True, # Auto-save cos/sin caches for rotary embeddings after creation + "cache_length": self.context_length, # Cache length to use when creating cos/sin caches for rotary embeddings + "theta": rope_theta, # Base value if calculating cos/sin caches from scratch "partial_rotary_factor": partial_rotary_factor, # Factor for partial rotary embeddings - "interleaved": 0, # Interleave the rotary embeddings (e.g. [0, 0, 0, 1, 1, 1] to [0, 1, 0, 1, 0, 1], RotaryEmbedding kernel expects a default value of 0) - "rotary_embedding_dim": rotemb_dim, # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0) - "rescale_factors": 1, # Rescale factors when calculating `inv_freq` in rotary embeddings - "t_dtype": torch.int64, # Torch dtype when calculating `t` in rotary embeddings - "position_scale": position_scale, # Scale value when calculating `t` in rotary embeddings - "mscale": 1, # Magnitude scaling factor when scaling `emb.cos()/emb.sin()` in rotary embeddings - "mscale_policy": "", # Magnitude scaling policy when scaling `emb.cos()/emb.sin()` in rotary embeddings + "interleaved": 0, # Interleave the rotary embeddings (e.g. [0, 0, 0, 1, 1, 1] to [0, 1, 0, 1, 0, 1], RotaryEmbedding kernel expects a default value of 0) + "rotary_embedding_dim": rotemb_dim, # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0) + "rescale_factors": 1, # Rescale factors when calculating `inv_freq` in rotary embeddings + "t_dtype": torch.int64, # Torch dtype when calculating `t` in rotary embeddings + "position_scale": position_scale, # Scale value when calculating `t` in rotary embeddings + "mscale": 1, # Magnitude scaling factor when scaling `emb.cos()/emb.sin()` in rotary embeddings + "mscale_policy": "", # Magnitude scaling policy when scaling `emb.cos()/emb.sin()` in rotary embeddings } if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.make_rope_init(config) # Attention-specific variables (MHA, GQA, GQA + Rot.Emb., etc.) - attn_softcap = config.attn_logit_softcapping if hasattr(config, "attn_logit_softcapping") and config.attn_logit_softcapping is not None else 0.0 # default is 0.0 in GroupQueryAttention kernel + attn_softcap = ( + config.attn_logit_softcapping + if hasattr(config, "attn_logit_softcapping") and config.attn_logit_softcapping is not None + else 0.0 + ) # default is 0.0 in GroupQueryAttention kernel # Block-sparse attention-specific variables sparse_block_size = config.blocksparse_block_size if hasattr(config, "blocksparse_block_size") else 0 - kernel_block_size = config.blocksparse_triton_kernel_block_size if hasattr(config, "blocksparse_triton_kernel_block_size") else 0 + kernel_block_size = ( + config.blocksparse_triton_kernel_block_size + if hasattr(config, "blocksparse_triton_kernel_block_size") + else 0 + ) local_blocks = config.blocksparse_num_local_blocks if hasattr(config, "blocksparse_num_local_blocks") else 0 vert_block_stride = config.blocksparse_vert_stride if hasattr(config, "blocksparse_vert_stride") else 0 homo_head = config.blocksparse_homo_head_pattern if hasattr(config, "blocksparse_homo_head_pattern") else False self.attention_attrs = { - "q_path": "", # Q path to attention - "k_path": "", # K path to attention - "v_path": "", # V path to attention - "op_type": "MultiHeadAttention", # Attention op to use - "scale": 1 / np.sqrt(self.head_size), # Scale value after calculating Q x K' in attention - "softcap": attn_softcap, # Softcap value to prevent values from exploding in attention - "use_rope_in_attn": False, # Use rotary embeddings within attention (instead of a separate RotaryEmbedding op) - "use_packed_matmul": False, # Use packed MatMul (instead of 3 separate MatMuls for Q/K/V) - "block_sparse": { # Block-sparse attention-specific variables - "sparse_block_size": sparse_block_size, # Sparse block size for SparseAttention op - "kernel_block_size": kernel_block_size, # Kernel block size for sparse attention - "local_blocks": local_blocks, # Number of local blocks for sparse attention - "vert_stride": vert_block_stride, # Vertical stride to use for sparse attention - "homo_head": homo_head, # Use homo head pattern for sparse attention + "q_path": "", # Q path to attention + "k_path": "", # K path to attention + "v_path": "", # V path to attention + "op_type": "MultiHeadAttention", # Attention op to use + "scale": 1 / np.sqrt(self.head_size), # Scale value after calculating Q x K' in attention + "softcap": attn_softcap, # Softcap value to prevent values from exploding in attention + "use_rope_in_attn": False, # Use rotary embeddings within attention (instead of a separate RotaryEmbedding op) + "use_packed_matmul": False, # Use packed MatMul (instead of 3 separate MatMuls for Q/K/V) + "block_sparse": { # Block-sparse attention-specific variables + "sparse_block_size": sparse_block_size, # Sparse block size for SparseAttention op + "kernel_block_size": kernel_block_size, # Kernel block size for sparse attention + "local_blocks": local_blocks, # Number of local blocks for sparse attention + "vert_stride": vert_block_stride, # Vertical stride to use for sparse attention + "homo_head": homo_head, # Use homo head pattern for sparse attention }, - "q_norm": False, # LayerNorm after MatMul in Q path - "k_norm": False, # LayerNorm after MatMul in K path - "sinks": False, # Sink values for softmax in attention + "q_norm": False, # LayerNorm after MatMul in Q path + "k_norm": False, # LayerNorm after MatMul in K path + "sinks": False, # Sink values for softmax in attention } self.make_attention_init() # MLP-specific variables self.mlp_attrs = { - "use_proj": True, # Use projection style for MLP (GateProj/UpProj/DownProj) - "use_fc": False, # Use fully-connected style for MLP (FC1/FC2) - "output_0": "", # Output 0 for MLP layer + "use_proj": True, # Use projection style for MLP (GateProj/UpProj/DownProj) + "use_fc": False, # Use fully-connected style for MLP (FC1/FC2) + "output_0": "", # Output 0 for MLP layer } # MoE-specific variables @@ -248,24 +332,28 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): expert_weight_bits = 8 if extra_options.get("use_8bits_moe", False) else 4 swiglu_limit = config.swiglu_limit if hasattr(config, "swiglu_limit") else None self.moe_attrs = { - "op_type": moe_op_type, # MoE op to use - "num_experts": num_experts, # Number of experts in MoE layer - "top_k": top_k_experts, # Number of experts to select in MoE layer - "activation_alpha": 1.0, # Alpha parameter used in activation function - "activation_beta": 0.0, # Beta parameter used in activation function - "activation_type": self.activation, # Activation function for MoE layer - "expert_weight_bits": expert_weight_bits, # Number of bits used in quantized MoE weights (only INT4 or INT8 are supported). - "normalize_routing_weights": False, # Normalize routing weights in MoE layer - "swiglu_fusion": 0, # Fusion level for SwiGLU activation function - "swiglu_limit": swiglu_limit, # Value used to clamp results into a certain range in SwiGLU activation function - "use_sparse_mixer": False, # Use SparseMixer in MoE layer (used in Phi-3.5 MoE) + "op_type": moe_op_type, # MoE op to use + "num_experts": num_experts, # Number of experts in MoE layer + "top_k": top_k_experts, # Number of experts to select in MoE layer + "activation_alpha": 1.0, # Alpha parameter used in activation function + "activation_beta": 0.0, # Beta parameter used in activation function + "activation_type": self.activation, # Activation function for MoE layer + "expert_weight_bits": expert_weight_bits, # Number of bits used in quantized MoE weights (only INT4 or INT8 are supported). + "normalize_routing_weights": False, # Normalize routing weights in MoE layer + "swiglu_fusion": 0, # Fusion level for SwiGLU activation function + "swiglu_limit": swiglu_limit, # Value used to clamp results into a certain range in SwiGLU activation function + "use_sparse_mixer": False, # Use SparseMixer in MoE layer (used in Phi-3.5 MoE) } # LM head-specific variables - lm_head_softcap = config.final_logit_softcapping if hasattr(config, "final_logit_softcapping") and config.final_logit_softcapping is not None else 0.0 # default is 0.0 in GroupQueryAttention kernel + lm_head_softcap = ( + config.final_logit_softcapping + if hasattr(config, "final_logit_softcapping") and config.final_logit_softcapping is not None + else 0.0 + ) # default is 0.0 in GroupQueryAttention kernel self.lm_head_attrs = { - "scale": 1, # Scale value to multiply output of LM head by - "mask": None, # LM head mask for tokens in the vocabulary + "scale": 1, # Scale value to multiply output of LM head by + "mask": None, # LM head mask for tokens in the vocabulary "softcap": lm_head_softcap, # Softcap value to prevent values from exploding in LM head } if hasattr(config, "dummy_token_indices"): @@ -279,10 +367,12 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): self.int4_block_size = extra_options.get("int4_block_size", 32) self.quant_attrs = { "int4": { - "accuracy_level": int(extra_options.get("int4_accuracy_level", 4 if self.ep in ["cpu", "webgpu"] else 0)), + "accuracy_level": int( + extra_options.get("int4_accuracy_level", 4 if self.ep in ["cpu", "webgpu"] else 0) + ), "block_size": int(self.int4_block_size), "is_symmetric": extra_options.get("int4_is_symmetric", True), - "op_types_to_quantize": extra_options.get("int4_op_types_to_quantize", ("MatMul", )), + "op_types_to_quantize": extra_options.get("int4_op_types_to_quantize", ("MatMul",)), "nodes_to_exclude": extra_options.get("int4_nodes_to_exclude", []), "algo_config": int4_algo_config, }, @@ -291,11 +381,20 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): if self.quant_type is not None: # Create quantized attributes from quantization config self.quant_attrs["config"] = config.quantization_config - self.quant_attrs["use_g_idx"] = config.quantization_config["desc_act"] if "desc_act" in config.quantization_config else False + self.quant_attrs["use_g_idx"] = ( + config.quantization_config["desc_act"] if "desc_act" in config.quantization_config else False + ) - self.int4_tied_embeddings = config.tie_word_embeddings if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None else False + self.int4_tied_embeddings = ( + config.tie_word_embeddings + if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None + else False + ) self.int4_tied_embeddings = extra_options.get("int4_tied_embeddings", self.int4_tied_embeddings) - self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last"} + self.int8_lm_head = extra_options.get("int4_algo_config", "default") in { + "k_quant_mixed", + "k_quant_last", + } if not self.int8_lm_head: # matmul_nbits_quantizer.py has a different naming for default quantization, so lm_head.MatMul.weight_Q{}G{} does not match. self.int4_tied_embeddings = False @@ -325,26 +424,34 @@ def make_rope_init(self, config): short_mscale = config.rope_scaling["short_mscale"] if "short_mscale" in config.rope_scaling else 0 long_mscale = config.rope_scaling["long_mscale"] if "long_mscale" in config.rope_scaling else 0 - short_mscale = short_mscale if short_mscale > 0 else self.make_mscale(self.context_length / self.original_context_length) - long_mscale = long_mscale if long_mscale > 0 else self.make_mscale(self.context_length / self.original_context_length) + short_mscale = ( + short_mscale + if short_mscale > 0 + else self.make_mscale(self.context_length / self.original_context_length) + ) + long_mscale = ( + long_mscale if long_mscale > 0 else self.make_mscale(self.context_length / self.original_context_length) + ) self.rope_attrs["multi_cache"] = { - "short_factor": short_factor, # Short factor when calculating `inv_freq` in rotary embeddings - "long_factor": long_factor, # Long factor when calculating `inv_freq` in rotary embeddings - "short_mscale": short_mscale, # Magnitude scaling for short factor when scaling `emb.cos()/emb.sin()` in rotary embeddings - "long_mscale": long_mscale, # Magnitude scaling for long factor when scaling `emb.cos()/emb.sin()` in rotary embeddings + "short_factor": short_factor, # Short factor when calculating `inv_freq` in rotary embeddings + "long_factor": long_factor, # Long factor when calculating `inv_freq` in rotary embeddings + "short_mscale": short_mscale, # Magnitude scaling for short factor when scaling `emb.cos()/emb.sin()` in rotary embeddings + "long_mscale": long_mscale, # Magnitude scaling for long factor when scaling `emb.cos()/emb.sin()` in rotary embeddings } elif "low_freq_factor" in config.rope_scaling: # For models that rescale `inv_freq` using `low_freq_factor` and `high_freq_factor` (e.g. LLaMA-3.1) factor = config.rope_scaling["factor"] if "factor" in config.rope_scaling else 0 low_freq_factor = config.rope_scaling["low_freq_factor"] if "low_freq_factor" in config.rope_scaling else 0 - high_freq_factor = config.rope_scaling["high_freq_factor"] if "high_freq_factor" in config.rope_scaling else 0 - + high_freq_factor = ( + config.rope_scaling["high_freq_factor"] if "high_freq_factor" in config.rope_scaling else 0 + ) + self.rope_attrs["rescale_inv_freq"] = { - "factor": factor, # Scale factor when calculating `new_freq` in rotary embeddings - "low_freq_factor": low_freq_factor, # Low freq factor when calculating `low_freq_wavelen` in rotary embeddings - "high_freq_factor": high_freq_factor, # High freq factor when calculating `high_freq_wavelen` in rotary embeddings + "factor": factor, # Scale factor when calculating `new_freq` in rotary embeddings + "low_freq_factor": low_freq_factor, # Low freq factor when calculating `low_freq_wavelen` in rotary embeddings + "high_freq_factor": high_freq_factor, # High freq factor when calculating `high_freq_wavelen` in rotary embeddings } elif "beta_fast" in config.rope_scaling: @@ -400,10 +507,20 @@ def make_attention_init(self): def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir): # Create config with attributes from config.json and generation_config.json (if latter file exists) - config = AutoConfig.from_pretrained(model_name_or_path, token=self.hf_token, trust_remote_code=self.hf_remote, **extra_kwargs) + config = AutoConfig.from_pretrained( + model_name_or_path, + token=self.hf_token, + trust_remote_code=self.hf_remote, + **extra_kwargs, + ) try: # Override search attributes in config based on values in generation_config.json - gen_config = GenerationConfig.from_pretrained(model_name_or_path, token=self.hf_token, trust_remote_code=self.hf_remote, **extra_kwargs) + gen_config = GenerationConfig.from_pretrained( + model_name_or_path, + token=self.hf_token, + trust_remote_code=self.hf_remote, + **extra_kwargs, + ) defaults = { "bos_token_id": None, "do_sample": False, @@ -420,31 +537,41 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir): except: pass - inputs = dict(zip(self.input_names, self.input_names)) - inputs.update({ - "past_key_names": "past_key_values.%d.key", - "past_value_names": "past_key_values.%d.value", - }) - outputs = dict(zip(self.output_names, self.output_names)) - outputs.update({ - "present_key_names": "present.%d.key", - "present_value_names": "present.%d.value", - }) + inputs = dict(zip(self.input_names, self.input_names, strict=False)) + inputs.update( + { + "past_key_names": "past_key_values.%d.key", + "past_value_names": "past_key_values.%d.value", + } + ) + outputs = dict(zip(self.output_names, self.output_names, strict=False)) + outputs.update( + { + "present_key_names": "present.%d.key", + "present_value_names": "present.%d.value", + } + ) if "hidden_states" in outputs: # Remove 'hidden_states' from 'outputs' entry in config since ORT GenAI doesn't use it del outputs["hidden_states"] bos_token_id = config.bos_token_id if hasattr(config, "bos_token_id") and config.bos_token_id is not None else 1 eos_token_id = config.eos_token_id - pad_token_id = config.pad_token_id if hasattr(config, "pad_token_id") and config.pad_token_id is not None else config.eos_token_id[0] if isinstance(config.eos_token_id, list) else config.eos_token_id + pad_token_id = ( + config.pad_token_id + if hasattr(config, "pad_token_id") and config.pad_token_id is not None + else config.eos_token_id[0] + if isinstance(config.eos_token_id, list) + else config.eos_token_id + ) genai_config = { "model": { "bos_token_id": bos_token_id, "context_length": self.context_length, "decoder": { - "session_options" : { + "session_options": { "log_id": "onnxruntime-genai", - "provider_options" : [], + "provider_options": [], }, "filename": self.filename, "head_size": self.head_size, @@ -457,7 +584,9 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir): }, "eos_token_id": eos_token_id, "pad_token_id": pad_token_id, - "type": self.model_type[ : self.model_type.find("For") if "For" in self.model_type else len(self.model_type)].lower(), + "type": self.model_type[ + : self.model_type.find("For") if "For" in self.model_type else len(self.model_type) + ].lower(), "vocab_size": self.vocab_size, }, "search": { @@ -470,7 +599,9 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir): "no_repeat_ngram_size": config.no_repeat_ngram_size if hasattr(config, "no_repeat_ngram_size") else 0, "num_beams": config.num_beams if hasattr(config, "num_beams") else 1, "num_return_sequences": config.num_return_sequences if hasattr(config, "num_return_sequences") else 1, - "past_present_share_buffer": False if "config_only" in self.extra_options else self.past_present_share_buffer, + "past_present_share_buffer": False + if "config_only" in self.extra_options + else self.past_present_share_buffer, "repetition_penalty": config.repetition_penalty if hasattr(config, "repetition_penalty") else 1.0, "temperature": config.temperature if hasattr(config, "temperature") else 1.0, "top_k": config.top_k if hasattr(config, "top_k") else 50, @@ -480,22 +611,24 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir): if self.ep == "trt-rtx" and self.window_size is not None and self.window_size > 0: # Compute layer indices that use sliding window attention - layer_idxs = [layer_id for layer_id in range(self.num_layers) if hasattr(self, "is_local") and self.is_local(layer_id)] - + layer_idxs = [ + layer_id for layer_id in range(self.num_layers) if hasattr(self, "is_local") and self.is_local(layer_id) + ] + genai_config["model"]["decoder"]["sliding_window"] = { "window_size": self.window_size, "slide_key_value_cache": False, "slide_inputs": False, - "layers": layer_idxs + "layers": layer_idxs, } if self.ep != "cpu": ep_name = self.ep.replace("trt-rtx", "NvTensorRtRtx") - ep_options = { ep_name : self.ep_attrs[self.ep] } + ep_options = {ep_name: self.ep_attrs[self.ep]} genai_config["model"]["decoder"]["session_options"]["provider_options"].append(ep_options) print(f"Saving GenAI config in {out_dir}") - with open(os.path.join(out_dir,"genai_config.json"), "w") as f: + with open(os.path.join(out_dir, "genai_config.json"), "w") as f: json.dump(genai_config, f, indent=4) def make_key_value_cache_shape(self, layer_id, shape): @@ -504,11 +637,21 @@ def make_key_value_cache_shape(self, layer_id, shape): For TensorRT EP with sliding window layers, replaces 'sequence' with 'sliding' in dimension name. """ if self.ep == "trt-rtx" and hasattr(self, "is_local") and self.is_local(layer_id): - return [shape[0], shape[1], shape[2].replace("sequence", "sliding"), shape[3]] + return [ + shape[0], + shape[1], + shape[2].replace("sequence", "sliding"), + shape[3], + ] return shape def save_processing(self, model_name_or_path, extra_kwargs, out_dir): - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token=self.hf_token, trust_remote_code=self.hf_remote, **extra_kwargs) + tokenizer = AutoTokenizer.from_pretrained( + model_name_or_path, + token=self.hf_token, + trust_remote_code=self.hf_remote, + **extra_kwargs, + ) print(f"Saving processing files in {out_dir} for GenAI") tokenizer.save_pretrained(out_dir) @@ -520,7 +663,9 @@ def make_int4_algo_config(self, quant_method: str): int4_algo_config = RTNWeightOnlyQuantConfig() elif quant_method in {"k_quant_mixed", "k_quant_last"}: - from onnxruntime.quantization.matmul_nbits_quantizer import KQuantWeightOnlyQuantConfig + from onnxruntime.quantization.matmul_nbits_quantizer import ( + KQuantWeightOnlyQuantConfig, + ) if quant_method == "k_quant_mixed": # k_quant_mixed is from llama.cpp. @@ -529,7 +674,9 @@ def make_int4_algo_config(self, quant_method: str): layers_to_exclude = [ i for i in range(self.num_layers) - if i < self.num_layers / 8 or i >= 7 * self.num_layers / 8 or (i - (round)(self.num_layers / 8)) % 3 == 2 + if i < self.num_layers / 8 + or i >= 7 * self.num_layers / 8 + or (i - (round)(self.num_layers / 8)) % 3 == 2 ] for i in layers_to_exclude: customized_weight_config["/model/layers." + str(i) + "/attn/qkv_proj/MatMul"] = {"bits": 8} @@ -558,7 +705,9 @@ def to_int4(self) -> ir.Model: def save_model(self, out_dir): print(f"Saving ONNX model in {out_dir}") - already_quantized_in_qdq_format = self.quant_type is not None and self.quant_attrs["use_qdq"] # Skip quantizing `MatMul` in `DequantizeLinear --> Transpose --> MatMul` path + already_quantized_in_qdq_format = ( + self.quant_type is not None and self.quant_attrs["use_qdq"] + ) # Skip quantizing `MatMul` in `DequantizeLinear --> Transpose --> MatMul` path if self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4} and not already_quantized_in_qdq_format: model = self.to_int4() else: @@ -601,7 +750,13 @@ def callback(tensor: ir.TensorProtocol, metadata: dict): if not os.listdir(self.cache_dir): os.rmdir(self.cache_dir) - def make_initializer(self, tensor: torch.Tensor | np.ndarray | ir.TensorProtocol, /, name: str, to: ir.DataType | None = None): + def make_initializer( + self, + tensor: torch.Tensor | np.ndarray | ir.TensorProtocol, + /, + name: str, + to: ir.DataType | None = None, + ): if to is not None: # Cast the tensor lazily if `to` is provided def tensor_func(): @@ -609,9 +764,7 @@ def tensor_func(): tensor = tensor.to(to_torch_dtype(to)) return TorchTensor(tensor, name=name) - ir_tensor = ir.LazyTensor( - tensor_func, dtype=to, shape=ir.Shape(tensor.shape), name=name - ) + ir_tensor = ir.LazyTensor(tensor_func, dtype=to, shape=ir.Shape(tensor.shape), name=name) elif isinstance(tensor, torch.nn.parameter.Parameter): ir_tensor = TorchTensor(tensor, name=name) else: @@ -620,7 +773,16 @@ def tensor_func(): value.const_value = ir_tensor self.model.graph.register_initializer(value) - def make_node(self, op_type, inputs: Sequence[str], outputs: Sequence[str], *, name: str, domain="", **kwargs): + def make_node( + self, + op_type, + inputs: Sequence[str], + outputs: Sequence[str], + *, + name: str, + domain="", + **kwargs, + ): assert name, "Node name must be provided" if name in self.node_names: # Note: @@ -642,11 +804,23 @@ def make_node(self, op_type, inputs: Sequence[str], outputs: Sequence[str], *, n # Resolve values from names input_values = [self.make_value(name) for name in inputs] output_values = [self.make_value(name) for name in outputs] - node = ir.node(op_type, inputs=input_values, attributes=kwargs, domain=domain, outputs=output_values, name=name) + node = ir.node( + op_type, + inputs=input_values, + attributes=kwargs, + domain=domain, + outputs=output_values, + name=name, + ) self.model.graph.append(node) self.node_names.add(name) - def make_value(self, name, dtype: ir.DataType | int| None = None, shape: Sequence[int | str] | ir.Shape | None = None) -> ir.Value: + def make_value( + self, + name, + dtype: ir.DataType | int | None = None, + shape: Sequence[int | str] | ir.Shape | None = None, + ) -> ir.Value: """Obtain or create an IR value by value name. If the value does not exist a new one is created. @@ -686,11 +860,23 @@ def make_inputs_and_outputs(self): # Add KV cache to inputs key_name = f"past_key_values.{i}.key" key_shape = self.make_key_value_cache_shape(i, self.input_shapes["past_key_values.key"]) - inputs.append(self.make_value(key_name, dtype=self.input_types["past_key_values.key"], shape=key_shape)) + inputs.append( + self.make_value( + key_name, + dtype=self.input_types["past_key_values.key"], + shape=key_shape, + ) + ) value_name = f"past_key_values.{i}.value" value_shape = self.make_key_value_cache_shape(i, self.input_shapes["past_key_values.value"]) - inputs.append(self.make_value(value_name, dtype=self.input_types["past_key_values.value"], shape=value_shape)) + inputs.append( + self.make_value( + value_name, + dtype=self.input_types["past_key_values.value"], + shape=value_shape, + ) + ) # Add KV cache to outputs key_name = f"present.{i}.key" @@ -699,7 +885,13 @@ def make_inputs_and_outputs(self): value_name = f"present.{i}.value" value_shape = self.make_key_value_cache_shape(i, self.output_shapes["present.value"]) - outputs.append(self.make_value(value_name, dtype=self.output_types["present.value"], shape=value_shape)) + outputs.append( + self.make_value( + value_name, + dtype=self.output_types["present.value"], + shape=value_shape, + ) + ) def make_constant(self, name): # Make constant ops for 0, 1, 2, 3, etc. @@ -732,7 +924,13 @@ def make_shape(self, name, root_input, shape): def make_constant_of_shape(self, name, root_input, value, dtype, shape): output = f"{name}/output_0" - self.make_node("ConstantOfShape", inputs=[root_input], outputs=[output], name=name, value=value) + self.make_node( + "ConstantOfShape", + inputs=[root_input], + outputs=[output], + name=name, + value=value, + ) self.make_value(output, dtype, shape=shape) def make_unsqueeze(self, name, inputs, dtype, shape): @@ -879,7 +1077,11 @@ def make_matmul(self, matmul, basename, root_input, **kwargs): return self.make_matmul_op(matmul, basename, root_input, **kwargs) def make_matmul_op(self, matmul, basename, root_input, **kwargs): - if self.onnx_dtype in {ir.DataType.FLOAT16, ir.DataType.BFLOAT16, ir.DataType.FLOAT}: + if self.onnx_dtype in { + ir.DataType.FLOAT16, + ir.DataType.BFLOAT16, + ir.DataType.FLOAT, + }: return self.make_matmul_float(matmul, basename, root_input, **kwargs) elif self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4}: if self.quant_attrs["use_qdq"]: @@ -896,7 +1098,7 @@ def make_matmul_float(self, matmul, name, root_input, **kwargs): last_dim = matmul.weight.shape[0] output = "logits" if kwargs.get("logits", False) else f"{name}/output_0" self.make_node("MatMul", inputs=[root_input, weight], outputs=[output], name=name) - self.make_value(output, self.io_dtype, shape=['batch_size', 'sequence_length', last_dim]) + self.make_value(output, self.io_dtype, shape=["batch_size", "sequence_length", last_dim]) return name @@ -929,11 +1131,22 @@ def make_matmul_int4(self, matmul, basename, root_input, **kwargs): output = "logits" if kwargs.get("logits", False) else f"{name}/output_0" self.make_node( - "MatMulNBits", inputs=inputs, outputs=[output], name=name, domain="com.microsoft", + "MatMulNBits", + inputs=inputs, + outputs=[output], + name=name, + domain="com.microsoft", accuracy_level=self.quant_attrs["int4"]["accuracy_level"], - bits=matmul.bits, block_size=matmul.group_size, K=matmul.in_features, N=matmul.out_features, + bits=matmul.bits, + block_size=matmul.group_size, + K=matmul.in_features, + N=matmul.out_features, + ) + self.make_value( + output, + self.io_dtype, + shape=["batch_size", "sequence_length", matmul.out_features], ) - self.make_value(output, self.io_dtype, shape=['batch_size', 'sequence_length', matmul.out_features]) return name @@ -964,16 +1177,28 @@ def make_dequantize_linear(self, dequantize_name, quantized_op): if hasattr(quantized_op, "qzeros") and quantized_op.qzeros is not None: zeros = dequantize_name[1:].replace("/", ".") + ".qzeros" self.make_initializer( - ir.PackedTensor( - quantized_op.qzeros, self.onnx_dtype, shape=scales_target_shape - ), + ir.PackedTensor(quantized_op.qzeros, self.onnx_dtype, shape=scales_target_shape), zeros, ) dequantize_inputs.append(zeros) dequantize_output = f"{dequantize_name}/output_0" - self.make_node("DequantizeLinear", inputs=dequantize_inputs, outputs=[dequantize_output], name=dequantize_name, block_size=quantized_op.group_size, axis=-1) - self.make_value(dequantize_output, self.io_dtype, shape=[*scales_pt.shape[:-1], scales_pt.shape[-1] * quantized_op.group_size]) + self.make_node( + "DequantizeLinear", + inputs=dequantize_inputs, + outputs=[dequantize_output], + name=dequantize_name, + block_size=quantized_op.group_size, + axis=-1, + ) + self.make_value( + dequantize_output, + self.io_dtype, + shape=[ + *scales_pt.shape[:-1], + scales_pt.shape[-1] * quantized_op.group_size, + ], + ) return dequantize_output @@ -1000,8 +1225,17 @@ def make_matmul_int4_qdq(self, matmul, matmul_name, root_input, **kwargs): self.make_transpose(transpose_name, dequantize_output, self.io_dtype, transposed_shape, [1, 0]) matmul_output = "logits" if kwargs.get("logits", False) else f"{matmul_name}/output_0" - self.make_node("MatMul", inputs=[root_input, f"{transpose_name}/output_0"], outputs=[matmul_output], name=matmul_name) - self.make_value(matmul_output, self.io_dtype, shape=['batch_size', 'sequence_length', matmul.out_features]) + self.make_node( + "MatMul", + inputs=[root_input, f"{transpose_name}/output_0"], + outputs=[matmul_output], + name=matmul_name, + ) + self.make_value( + matmul_output, + self.io_dtype, + shape=["batch_size", "sequence_length", matmul.out_features], + ) return matmul_name @@ -1046,7 +1280,11 @@ def make_matmul_lora(self, matmul, basename, root_input, **kwargs): return add_name def make_packed_matmul(self, q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs): - if self.onnx_dtype in {ir.DataType.FLOAT, ir.DataType.FLOAT16, ir.DataType.BFLOAT16}: + if self.onnx_dtype in { + ir.DataType.FLOAT, + ir.DataType.FLOAT16, + ir.DataType.BFLOAT16, + }: return self.make_packed_matmul_float(q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs) elif self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4}: return self.make_packed_matmul_int4(q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs) @@ -1064,7 +1302,10 @@ def make_packed_matmul_float(self, q_matmul, k_matmul, v_matmul, basename, root_ # Create dummy PackedMatMul class class PackedMatMul: def __init__(self): - self.weight = torch.cat([q_matmul.weight, k_matmul.weight, v_matmul.weight], dim=0).reshape(N_q + N_kv + N_kv, H) + self.weight = torch.cat([q_matmul.weight, k_matmul.weight, v_matmul.weight], dim=0).reshape( + N_q + N_kv + N_kv, H + ) + matmul = PackedMatMul() new_name = self.make_matmul(matmul, basename, root_input, **kwargs) @@ -1093,6 +1334,7 @@ def __init__(self): self.out_features = q_matmul.out_features + k_matmul.out_features + v_matmul.out_features self.bits = q_matmul.bits self.group_size = q_matmul.group_size + matmul = PackedMatMul() new_name = self.make_matmul_int4(matmul, basename, root_input, **kwargs) @@ -1103,7 +1345,7 @@ def make_add_bias(self, add, name, root_input, **kwargs): self.make_initializer(add, bias, to=self.io_dtype) add_bias_inputs = [root_input, bias] - shape = ['batch_size', 'sequence_length', add.shape[0]] + shape = ["batch_size", "sequence_length", add.shape[0]] if kwargs.get("logits", False): output = "logits" @@ -1125,30 +1367,67 @@ def make_embedding(self, embedding): weight_reshape_name = f"{basename}/Reshape" bits = 8 if self.int8_lm_head else 4 - weight_reshape_inputs = [f"lm_head.MatMul.weight_Q{bits}G{self.int4_block_size}", f"/model/constants/INT64/[{self.vocab_size}, {self.hidden_size}]"] + weight_reshape_inputs = [ + f"lm_head.MatMul.weight_Q{bits}G{self.int4_block_size}", + f"/model/constants/INT64/[{self.vocab_size}, {self.hidden_size}]", + ] weight_reshape_output = f"{weight_reshape_name}/output_0" # quantized weight dtype is uint8, see here # https://github.com/microsoft/onnxruntime/blob/0c9356cb986fd4cd2c5d510909d31186010ba226/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py#L73 - self.make_reshape(weight_reshape_name, weight_reshape_inputs, dtype=ir.DataType.UINT8, shape=['vocab_size', 'hidden_size']) + self.make_reshape( + weight_reshape_name, + weight_reshape_inputs, + dtype=ir.DataType.UINT8, + shape=["vocab_size", "hidden_size"], + ) - self.make_node('GatherBlockQuantized', inputs=[weight_reshape_output, 'input_ids', 'lm_head.MatMul.weight_scale', 'lm_head.MatMul.weight_zp'], outputs=[gather_output], name=gather_name, domain="com.microsoft", bits=bits, block_size=int(self.int4_block_size)) + self.make_node( + "GatherBlockQuantized", + inputs=[ + weight_reshape_output, + "input_ids", + "lm_head.MatMul.weight_scale", + "lm_head.MatMul.weight_zp", + ], + outputs=[gather_output], + name=gather_name, + domain="com.microsoft", + bits=bits, + block_size=int(self.int4_block_size), + ) else: weight = "model.embed_tokens.weight" self.make_initializer(embedding, weight, to=self.io_dtype) gather_name = f"{basename}/Gather" gather_output = f"{gather_name}/output_0" - self.make_node('Gather', inputs=[weight, 'input_ids'], outputs=[gather_output], name=gather_name) + self.make_node( + "Gather", + inputs=[weight, "input_ids"], + outputs=[gather_output], + name=gather_name, + ) - self.make_value(gather_output, self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size]) + self.make_value( + gather_output, + self.io_dtype, + shape=["batch_size", "sequence_length", self.hidden_size], + ) if self.embed_attrs["scale"] != 1: # Scale the embeddings mul_name = f"{basename}/Mul" - mul_inputs = [gather_output, f"/model/constants/{self.to_str_dtype(self.io_dtype)}/{self.embed_attrs['scale']}"] + mul_inputs = [ + gather_output, + f"/model/constants/{self.to_str_dtype(self.io_dtype)}/{self.embed_attrs['scale']}", + ] mul_output = f"{mul_name}/output_0" - self.make_node('Mul', inputs=mul_inputs, outputs=[mul_output], name=mul_name) - self.make_value(mul_output, self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size]) + self.make_node("Mul", inputs=mul_inputs, outputs=[mul_output], name=mul_name) + self.make_value( + mul_output, + self.io_dtype, + shape=["batch_size", "sequence_length", self.hidden_size], + ) layernorm_attrs_value = mul_output else: @@ -1157,7 +1436,12 @@ def make_embedding(self, embedding): if self.layernorm_attrs["cast"]["use_fp32"] and self.io_dtype != ir.DataType.FLOAT: # Insert output Cast node cast_name = f"{basename}/Cast" - self.make_cast(cast_name, layernorm_attrs_value, ir.DataType.FLOAT, shape=['batch_size', 'sequence_length', self.hidden_size]) + self.make_cast( + cast_name, + layernorm_attrs_value, + ir.DataType.FLOAT, + shape=["batch_size", "sequence_length", self.hidden_size], + ) layernorm_attrs_value = f"{cast_name}/output_0" self.layernorm_attrs["root_input"] = layernorm_attrs_value @@ -1184,7 +1468,7 @@ def make_layernorm_op(self, layer_id, layernorm, skip, simple, location): self.make_initializer( layernorm.weight + self.layernorm_attrs["add_offset"], weight, - to=new_io_dtype + to=new_io_dtype, ) bias = f"model.layers.{layer_id}.{location}_layernorm.bias" if not simple: @@ -1213,10 +1497,25 @@ def make_layernorm_op(self, layer_id, layernorm, skip, simple, location): inputs, outputs = self.make_layernorm_casts(name, inputs, outputs, old_io_dtype, new_io_dtype) # Make op and its shape - self.make_node(op_type, inputs=inputs, outputs=outputs, name=name, domain=("com.microsoft" if skip else None), **kwargs) - self.make_value(outputs[0], new_io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size]) + self.make_node( + op_type, + inputs=inputs, + outputs=outputs, + name=name, + domain=("com.microsoft" if skip else None), + **kwargs, + ) + self.make_value( + outputs[0], + new_io_dtype, + shape=["batch_size", "sequence_length", self.hidden_size], + ) if skip and not self.layernorm_attrs["last_layernorm"]: - self.make_value(outputs[3], new_io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size]) + self.make_value( + outputs[3], + new_io_dtype, + shape=["batch_size", "sequence_length", self.hidden_size], + ) # Update LayerNorm attributes self.layernorm_attrs["output_0"] = output_0 @@ -1240,13 +1539,13 @@ def _make_layernorm_op(self, layer_id, layernorm, skip, simple, location): self.make_initializer( layernorm.weight + self.layernorm_attrs["add_offset"], weight, - to=new_io_dtype + to=new_io_dtype, ) bias = f"model.layers.{layer_id}.{location}_layernorm.bias" if not simple: self.make_initializer(layernorm.bias, bias, to=new_io_dtype) - # Create input names for op + # Create input names for op inputs = [root_input, skip_input, weight] if skip else [root_input, weight] if not simple: inputs.append(bias) @@ -1271,16 +1570,46 @@ def _make_layernorm_op(self, layer_id, layernorm, skip, simple, location): skip_input = inputs[1] if skip else None if op_type == "SimplifiedLayerNormalization": - self._make_simplified_layer_norm(name, root_input, weight, outputs[0], new_io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size]) + self._make_simplified_layer_norm( + name, + root_input, + weight, + outputs[0], + new_io_dtype, + shape=["batch_size", "sequence_length", self.hidden_size], + ) elif op_type == "SkipSimplifiedLayerNormalization": - self._make_skip_simplified_layer_norm(name, root_input, skip_input, weight, outputs[0], output_3, new_io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size]) + self._make_skip_simplified_layer_norm( + name, + root_input, + skip_input, + weight, + outputs[0], + output_3, + new_io_dtype, + shape=["batch_size", "sequence_length", self.hidden_size], + ) elif op_type == "SkipLayerNormalization": - self._make_skip_layer_norm(name, root_input, skip_input, weight, bias, outputs[0], output_3, new_io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size]) + self._make_skip_layer_norm( + name, + root_input, + skip_input, + weight, + bias, + outputs[0], + output_3, + new_io_dtype, + shape=["batch_size", "sequence_length", self.hidden_size], + ) else: raise ValueError(f"Invalid op_type: {op_type}") if skip and not self.layernorm_attrs["last_layernorm"]: - self.make_value(outputs[3], new_io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size]) + self.make_value( + outputs[3], + new_io_dtype, + shape=["batch_size", "sequence_length", self.hidden_size], + ) # Update LayerNorm attributes self.layernorm_attrs["output_0"] = output_0 @@ -1308,7 +1637,13 @@ def make_layernorm_casts(self, name, inputs, outputs, old_dtype, new_dtype): # Cast root_input root_input_cast_name = f"{name}/root_input/Cast" root_input_cast_output = f"{root_input_cast_name}/output_0" - self.make_node("Cast", inputs=[root_input], outputs=[root_input_cast_output], name=root_input_cast_name, to=new_dtype) + self.make_node( + "Cast", + inputs=[root_input], + outputs=[root_input_cast_output], + name=root_input_cast_name, + to=new_dtype, + ) self.make_value(root_input_cast_output, new_dtype, shape=root_input_shape) inputs[0] = root_input_cast_output @@ -1317,7 +1652,13 @@ def make_layernorm_casts(self, name, inputs, outputs, old_dtype, new_dtype): assert skip_input is not None skip_input_cast_name = f"{name}/skip_input/Cast" skip_input_cast_output = f"{skip_input_cast_name}/output_0" - self.make_node("Cast", inputs=[skip_input], outputs=[skip_input_cast_output], name=skip_input_cast_name, to=new_dtype) + self.make_node( + "Cast", + inputs=[skip_input], + outputs=[skip_input_cast_output], + name=skip_input_cast_name, + to=new_dtype, + ) self.make_value(skip_input_cast_output, new_dtype, shape=self.values[skip_input].shape) inputs[1] = skip_input_cast_output @@ -1325,7 +1666,13 @@ def make_layernorm_casts(self, name, inputs, outputs, old_dtype, new_dtype): # Cast output_0 output_0_cast_name = f"{name}/output_0/Cast" output_0_cast_output = f"{output_0_cast_name}/output_0" - self.make_node("Cast", inputs=[output_0_cast_output], outputs=[output_0], name=output_0_cast_name, to=old_dtype) + self.make_node( + "Cast", + inputs=[output_0_cast_output], + outputs=[output_0], + name=output_0_cast_name, + to=old_dtype, + ) self.make_value(output_0, old_dtype, shape=root_input_shape) outputs[0] = output_0_cast_output @@ -1334,7 +1681,13 @@ def make_layernorm_casts(self, name, inputs, outputs, old_dtype, new_dtype): assert output_3 is not None output_3_cast_name = f"{name}/output_3/Cast" output_3_cast_output = f"{output_3_cast_name}/output_3" - self.make_node("Cast", inputs=[output_3_cast_output], outputs=[output_3], name=output_3_cast_name, to=old_dtype) + self.make_node( + "Cast", + inputs=[output_3_cast_output], + outputs=[output_3], + name=output_3_cast_name, + to=old_dtype, + ) self.make_value(output_3, old_dtype, shape=root_input_shape) outputs[3] = output_3_cast_output @@ -1364,7 +1717,7 @@ def make_inv_freq_rescaled(self, inv_freq): elif "ntk_alpha" in self.rope_attrs["rescale_inv_freq"]: return self.make_inv_freq_rescaled_with_ntk(inv_freq) else: - raise NotImplementedError(f"The method to rescale inv_freq could not be identified.") + raise NotImplementedError("The method to rescale inv_freq could not be identified.") def make_inv_freq_rescaled_with_freq_factors(self, inv_freq): scale_factor = self.rope_attrs["rescale_inv_freq"]["factor"] @@ -1405,9 +1758,7 @@ def make_inv_freq_rescaled_with_ntk(self, inv_freq): interpolation = 1.0 / (self.rope_attrs["rescale_inv_freq"]["factor"] * inv_freq) extrapolation = 1.0 / inv_freq - ramp = ( - torch.arange(d_half, dtype=torch.float32, device=inv_freq.device) - low - ) / (high - low) + ramp = (torch.arange(d_half, dtype=torch.float32, device=inv_freq.device) - low) / (high - low) mask = 1 - ramp.clamp(0, 1) inv_freq = interpolation * (1 - mask) + extrapolation * mask @@ -1415,16 +1766,24 @@ def make_inv_freq_rescaled_with_ntk(self, inv_freq): def make_rotary_embedding_caches_from_scratch(self): dim = int(self.rope_attrs["partial_rotary_factor"] * self.head_size) - inv_freq = 1.0 / (self.rope_attrs["rescale_factors"] * (self.rope_attrs["theta"] ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))) + inv_freq = 1.0 / ( + self.rope_attrs["rescale_factors"] + * (self.rope_attrs["theta"] ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + ) if "rescale_inv_freq" in self.rope_attrs: inv_freq = self.make_inv_freq_rescaled(inv_freq) position_scale = self.rope_attrs["position_scale"] if self.context_length == self.original_context_length else 1 - t = (torch.arange(self.rope_attrs["cache_length"], dtype=self.rope_attrs["t_dtype"]) * position_scale).type_as(inv_freq) + t = (torch.arange(self.rope_attrs["cache_length"], dtype=self.rope_attrs["t_dtype"]) * position_scale).type_as( + inv_freq + ) freqs = torch.outer(t, inv_freq) emb = torch.cat((freqs, freqs), dim=-1) - cos_cache, sin_cache = emb.cos() * self.rope_attrs["mscale"], emb.sin() * self.rope_attrs["mscale"] + cos_cache, sin_cache = ( + emb.cos() * self.rope_attrs["mscale"], + emb.sin() * self.rope_attrs["mscale"], + ) return cos_cache, sin_cache def make_rotary_embedding_caches(self, **kwargs): @@ -1479,16 +1838,25 @@ def make_padded_cache(self, small_cache, large_cache, pad_value=0.0): # Create padded tensor filled with pad_value padded_cache = torch.full(target_shape, pad_value, dtype=small_cache.dtype) # Copy original data to the beginning - padded_cache[:small_cache.shape[0], :] = small_cache + padded_cache[: small_cache.shape[0], :] = small_cache return padded_cache - def _make_split_if_nodes_for_trt_rtx(self, basename, greater_name, - cos_cache_name, sin_cache_name, - cos_cache_large, sin_cache_large, - cos_cache_small, sin_cache_small, - cos_cache_large_name, sin_cache_large_name, - cos_cache_small_name, sin_cache_small_name, - small_cache_shape): + def _make_split_if_nodes_for_trt_rtx( + self, + basename, + greater_name, + cos_cache_name, + sin_cache_name, + cos_cache_large, + sin_cache_large, + cos_cache_small, + sin_cache_small, + cos_cache_large_name, + sin_cache_large_name, + cos_cache_small_name, + sin_cache_small_name, + small_cache_shape, + ): """Create split If nodes for TRT-RTX to workaround trt-rtx multi-output bug. This is a TEMPORARY workaround for TRT-RTX bug where If nodes with @@ -1503,19 +1871,38 @@ def _make_split_if_nodes_for_trt_rtx(self, basename, greater_name, cos_if_name = f"{basename}/cos/If" cos_large_for_split = ir.node( - "Constant", [], outputs=[ - ir.Value(name=f"{cos_cache_large_name}_split", type=ir.TensorType(self.io_dtype), shape=ir.Shape(cos_cache_large.shape)) + "Constant", + [], + outputs=[ + ir.Value( + name=f"{cos_cache_large_name}_split", + type=ir.TensorType(self.io_dtype), + shape=ir.Shape(cos_cache_large.shape), + ) ], - name=f"/large/cos_cache/Constant_split_cos", attributes=dict(value=ir.tensor(cos_cache_large))) + name="/large/cos_cache/Constant_split_cos", + attributes=dict(value=ir.tensor(cos_cache_large)), + ) cos_small_for_split = ir.node( - "Constant", [], outputs=[ - ir.Value(name=f"{cos_cache_small_name}_split", type=ir.TensorType(self.io_dtype), shape=ir.Shape(small_cache_shape)) + "Constant", + [], + outputs=[ + ir.Value( + name=f"{cos_cache_small_name}_split", + type=ir.TensorType(self.io_dtype), + shape=ir.Shape(small_cache_shape), + ) ], - name=f"/small/cos_cache/Constant_split_cos", attributes=dict(value=ir.tensor(cos_cache_small))) + name="/small/cos_cache/Constant_split_cos", + attributes=dict(value=ir.tensor(cos_cache_small)), + ) self.make_node( - "If", inputs=[f"{greater_name}/output_0"], outputs=[cos_cache_name], name=cos_if_name, + "If", + inputs=[f"{greater_name}/output_0"], + outputs=[cos_cache_name], + name=cos_if_name, then_branch=ir.Graph( inputs=[], outputs=[cos_large_for_split.outputs[0]], @@ -1535,19 +1922,38 @@ def _make_split_if_nodes_for_trt_rtx(self, basename, greater_name, # Create unique constant nodes for sin to avoid tensor sharing sin_large_for_split = ir.node( - "Constant", [], outputs=[ - ir.Value(name=f"{sin_cache_large_name}_split", type=ir.TensorType(self.io_dtype), shape=ir.Shape(sin_cache_large.shape)) + "Constant", + [], + outputs=[ + ir.Value( + name=f"{sin_cache_large_name}_split", + type=ir.TensorType(self.io_dtype), + shape=ir.Shape(sin_cache_large.shape), + ) ], - name=f"/large/sin_cache/Constant_split_sin", attributes=dict(value=ir.tensor(sin_cache_large))) + name="/large/sin_cache/Constant_split_sin", + attributes=dict(value=ir.tensor(sin_cache_large)), + ) sin_small_for_split = ir.node( - "Constant", [], outputs=[ - ir.Value(name=f"{sin_cache_small_name}_split", type=ir.TensorType(self.io_dtype), shape=ir.Shape(small_cache_shape)) + "Constant", + [], + outputs=[ + ir.Value( + name=f"{sin_cache_small_name}_split", + type=ir.TensorType(self.io_dtype), + shape=ir.Shape(small_cache_shape), + ) ], - name=f"/small/sin_cache/Constant_split_sin", attributes=dict(value=ir.tensor(sin_cache_small))) + name="/small/sin_cache/Constant_split_sin", + attributes=dict(value=ir.tensor(sin_cache_small)), + ) self.make_node( - "If", inputs=[f"{greater_name}/output_0"], outputs=[sin_cache_name], name=sin_if_name, + "If", + inputs=[f"{greater_name}/output_0"], + outputs=[sin_cache_name], + name=sin_if_name, then_branch=ir.Graph( inputs=[], outputs=[sin_large_for_split.outputs[0]], @@ -1570,14 +1976,30 @@ def make_rotary_embedding(self, name, root_input, **kwargs): cos_cache_name, sin_cache_name = self.make_rotary_embedding_caches() num_heads = self.num_kv_heads if "k_rotary" in name else self.num_attn_heads - inputs = [root_input, kwargs.pop("position_ids"), cos_cache_name, sin_cache_name] + inputs = [ + root_input, + kwargs.pop("position_ids"), + cos_cache_name, + sin_cache_name, + ] output = f"{name}/output_0" self.make_node( - "RotaryEmbedding", inputs=inputs, outputs=[output], name=name, domain="com.microsoft", - interleaved=self.rope_attrs["interleaved"], num_heads=(0 if self.rope_attrs["partial_rotary_factor"] == 1.0 else num_heads), # default is 0 in RotaryEmbedding kernel + "RotaryEmbedding", + inputs=inputs, + outputs=[output], + name=name, + domain="com.microsoft", + interleaved=self.rope_attrs["interleaved"], + num_heads=( + 0 if self.rope_attrs["partial_rotary_factor"] == 1.0 else num_heads + ), # default is 0 in RotaryEmbedding kernel rotary_embedding_dim=self.rope_attrs["rotary_embedding_dim"], ) - self.make_value(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.head_size * num_heads]) + self.make_value( + output, + self.io_dtype, + shape=["batch_size", "sequence_length", self.head_size * num_heads], + ) def make_rotary_embedding_multi_cache(self, **kwargs): cos_cache_name = kwargs.get("cos_cache_name", "cos_cache") @@ -1589,9 +2011,14 @@ def make_rotary_embedding_multi_cache(self, **kwargs): self.rope_attrs["mscale"] = self.rope_attrs["multi_cache"]["long_mscale"] # Create caches for when sequence_length > self.original_context_length - cos_cache_large_name, sin_cache_large_name = "cos_cache_large", "sin_cache_large" + cos_cache_large_name, sin_cache_large_name = ( + "cos_cache_large", + "sin_cache_large", + ) self.rope_attrs["save_caches"] = False - cos_cache_large, sin_cache_large = self.make_rotary_embedding_caches(cos_cache_name=cos_cache_large_name, sin_cache_name=sin_cache_large_name) + cos_cache_large, sin_cache_large = self.make_rotary_embedding_caches( + cos_cache_name=cos_cache_large_name, sin_cache_name=sin_cache_large_name + ) # Set cache attributes for when sequence_length <= self.original_context_length self.rope_attrs["rescale_factors"] = self.rope_attrs["multi_cache"]["short_factor"] @@ -1600,9 +2027,14 @@ def make_rotary_embedding_multi_cache(self, **kwargs): self.rope_attrs["create_caches"] = True # Create caches for when sequence_length <= self.original_context_length - cos_cache_small_name, sin_cache_small_name = "cos_cache_small", "sin_cache_small" + cos_cache_small_name, sin_cache_small_name = ( + "cos_cache_small", + "sin_cache_small", + ) self.rope_attrs["save_caches"] = False - cos_cache_small, sin_cache_small = self.make_rotary_embedding_caches(cos_cache_name=cos_cache_small_name, sin_cache_name=sin_cache_small_name) + cos_cache_small, sin_cache_small = self.make_rotary_embedding_caches( + cos_cache_name=cos_cache_small_name, sin_cache_name=sin_cache_small_name + ) # Determine which EPs don't support the If operator self.eps_without_if_support = ["dml"] @@ -1634,7 +2066,10 @@ def make_rotary_embedding_multi_cache(self, **kwargs): gather_name = "/model/attn_mask_reformat/attn_mask_subgraph/Gather_2" greater_name = f"{basename}/Greater" - greater_inputs = [f"{gather_name}/output_0", f"/model/constants/INT64/{self.original_context_length}"] + greater_inputs = [ + f"{gather_name}/output_0", + f"/model/constants/INT64/{self.original_context_length}", + ] self.make_greater(greater_name, greater_inputs, shape=[]) # Create split If nodes and return early @@ -1651,7 +2086,7 @@ def make_rotary_embedding_multi_cache(self, **kwargs): sin_cache_large_name=sin_cache_large_name, cos_cache_small_name=cos_cache_small_name, sin_cache_small_name=sin_cache_small_name, - small_cache_shape=cos_cache_large.shape + small_cache_shape=cos_cache_large.shape, ) return @@ -1670,34 +2105,72 @@ def make_rotary_embedding_multi_cache(self, **kwargs): gather_name = "/model/attn_mask_reformat/attn_mask_subgraph/Gather_2" greater_name = f"{basename}/Greater" - greater_inputs = [f"{gather_name}/output_0", f"/model/constants/INT64/{self.original_context_length}"] + greater_inputs = [ + f"{gather_name}/output_0", + f"/model/constants/INT64/{self.original_context_length}", + ] self.make_greater(greater_name, greater_inputs, shape=[]) if_name = f"{basename}/If" cos_cache_large_node = ir.node( - "Constant", [], outputs=[ - ir.Value(name=cos_cache_large_name, type=ir.TensorType(self.io_dtype), shape=ir.Shape(cos_cache_large.shape)) + "Constant", + [], + outputs=[ + ir.Value( + name=cos_cache_large_name, + type=ir.TensorType(self.io_dtype), + shape=ir.Shape(cos_cache_large.shape), + ) ], - name="/large/cos_cache/Constant", attributes=dict(value=ir.tensor(cos_cache_large))) + name="/large/cos_cache/Constant", + attributes=dict(value=ir.tensor(cos_cache_large)), + ) sin_cache_large_node = ir.node( - "Constant", [], outputs=[ - ir.Value(name=sin_cache_large_name, type=ir.TensorType(self.io_dtype), shape=ir.Shape(sin_cache_large.shape)) + "Constant", + [], + outputs=[ + ir.Value( + name=sin_cache_large_name, + type=ir.TensorType(self.io_dtype), + shape=ir.Shape(sin_cache_large.shape), + ) ], - name="/large/sin_cache/Constant", attributes=dict(value=ir.tensor(sin_cache_large))) + name="/large/sin_cache/Constant", + attributes=dict(value=ir.tensor(sin_cache_large)), + ) cos_cache_small_node = ir.node( - "Constant", [], outputs=[ - ir.Value(name=cos_cache_small_name, type=ir.TensorType(self.io_dtype), shape=ir.Shape(cos_cache_small.shape)) + "Constant", + [], + outputs=[ + ir.Value( + name=cos_cache_small_name, + type=ir.TensorType(self.io_dtype), + shape=ir.Shape(cos_cache_small.shape), + ) ], - name="/small/cos_cache/Constant", attributes=dict(value=ir.tensor(cos_cache_small))) + name="/small/cos_cache/Constant", + attributes=dict(value=ir.tensor(cos_cache_small)), + ) sin_cache_small_node = ir.node( - "Constant", [], outputs=[ - ir.Value(name=sin_cache_small_name, type=ir.TensorType(self.io_dtype), shape=ir.Shape(sin_cache_small.shape)) + "Constant", + [], + outputs=[ + ir.Value( + name=sin_cache_small_name, + type=ir.TensorType(self.io_dtype), + shape=ir.Shape(sin_cache_small.shape), + ) ], - name="/small/sin_cache/Constant", attributes=dict(value=ir.tensor(sin_cache_small))) + name="/small/sin_cache/Constant", + attributes=dict(value=ir.tensor(sin_cache_small)), + ) # Create single If node with multiple outputs self.make_node( - "If", inputs=[f"{greater_name}/output_0"], outputs=[cos_cache_name, sin_cache_name], name=if_name, + "If", + inputs=[f"{greater_name}/output_0"], + outputs=[cos_cache_name, sin_cache_name], + name=if_name, then_branch=ir.Graph( inputs=[], outputs=[ @@ -1727,7 +2200,17 @@ def make_rotary_embedding_multi_cache(self, **kwargs): self.make_value(sin_cache_name, self.io_dtype, shape=["max_sequence_length", "head_dim / 2"]) # This expansion of contrib-op can be updated / deprecated in future. - def _make_skip_simplified_layer_norm(self, basename, root_input, skip_input, weight_name, output_0, output_3, io_dtype, shape): + def _make_skip_simplified_layer_norm( + self, + basename, + root_input, + skip_input, + weight_name, + output_0, + output_3, + io_dtype, + shape, + ): # root_input skip_input # | | # +------------------+ @@ -1737,14 +2220,41 @@ def _make_skip_simplified_layer_norm(self, basename, root_input, skip_input, wei # SimplifiedLayerNorm----> output (0) make_add_name = f"{basename}/Add" output_3 = f"{basename}/Add/output_0" if output_3 is None else output_3 - self.make_node("Add", inputs=[root_input, skip_input], outputs=[output_3], name=make_add_name) - self.make_value(output_3, io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size]) + self.make_node( + "Add", + inputs=[root_input, skip_input], + outputs=[output_3], + name=make_add_name, + ) + self.make_value( + output_3, + io_dtype, + shape=["batch_size", "sequence_length", self.hidden_size], + ) make_simplified_layer_norm_name = f"{basename}/skip_simplified_layer_norm" - self._make_simplified_layer_norm(make_simplified_layer_norm_name, output_3, weight_name, output_0, io_dtype, shape=shape) + self._make_simplified_layer_norm( + make_simplified_layer_norm_name, + output_3, + weight_name, + output_0, + io_dtype, + shape=shape, + ) # This expansion contrib-op can be updated / deprecated in the future. - def _make_skip_layer_norm(self, basename, root_input, skip_input, weight_name, bias_name, output_0, output_3, io_dtype, shape): + def _make_skip_layer_norm( + self, + basename, + root_input, + skip_input, + weight_name, + bias_name, + output_0, + output_3, + io_dtype, + shape, + ): # root_input skip_input # | | # +------------------+ @@ -1754,8 +2264,17 @@ def _make_skip_layer_norm(self, basename, root_input, skip_input, weight_name, b # LayerNormalization-----> output (0) output_3 = f"{basename}/Add/output_0" if output_3 is None else output_3 make_add_name = f"{basename}/Add" - self.make_node("Add", inputs=[root_input, skip_input], outputs=[output_3], name=make_add_name) - self.make_value(output_3, io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size]) + self.make_node( + "Add", + inputs=[root_input, skip_input], + outputs=[output_3], + name=make_add_name, + ) + self.make_value( + output_3, + io_dtype, + shape=["batch_size", "sequence_length", self.hidden_size], + ) make_layer_norm_name = f"{basename}/LayerNormalization" inputs = [output_3, weight_name, bias_name] @@ -1763,12 +2282,17 @@ def _make_skip_layer_norm(self, basename, root_input, skip_input, weight_name, b kwargs = {"epsilon": self.layernorm_attrs["epsilon"]} kwargs.update({"axis": -1, "stash_type": 1}) - self.make_node("LayerNormalization", inputs=inputs, outputs=[output_0], name=make_layer_norm_name, **kwargs) + self.make_node( + "LayerNormalization", + inputs=inputs, + outputs=[output_0], + name=make_layer_norm_name, + **kwargs, + ) self.make_value(output_0, io_dtype, shape=shape) # This expansion contrib-op can be updated / deprecated in the future. def _make_simplified_layer_norm(self, basename, root_input, weight_name, output_0, io_dtype, shape): - # Cast (float32) - most calc happens in higher precision # | # +-------+-------+ @@ -1797,15 +2321,33 @@ def _make_simplified_layer_norm(self, basename, root_input, weight_name, output_ make_pow_name = f"{basename}/Pow" make_pow_inputs = [f"{make_cast_name}/output_0", "/model/constants/FLOAT/2"] - self.make_node("Pow", inputs=make_pow_inputs, outputs=[f"{make_pow_name}/output_0"], name=make_pow_name, domain="") + self.make_node( + "Pow", + inputs=make_pow_inputs, + outputs=[f"{make_pow_name}/output_0"], + name=make_pow_name, + domain="", + ) self.make_value(f"{make_pow_name}/output_0", ir.DataType.FLOAT, shape=shape) make_reducemean_name = f"{basename}/ReduceMean" - make_reducemean_inputs = [f"{make_pow_name}/output_0", "/model/constants/INT64/[-1]"] - self.make_reduce_mean(make_reducemean_name, make_reducemean_inputs, ir.DataType.FLOAT, keepdims=True, shape=shape) + make_reducemean_inputs = [ + f"{make_pow_name}/output_0", + "/model/constants/INT64/[-1]", + ] + self.make_reduce_mean( + make_reducemean_name, + make_reducemean_inputs, + ir.DataType.FLOAT, + keepdims=True, + shape=shape, + ) make_add_name = f"{basename}/Add" - make_add_inputs = [f"{make_reducemean_name}/output_0", f"/model/constants/FLOAT/{self.layernorm_attrs['epsilon']}"] + make_add_inputs = [ + f"{make_reducemean_name}/output_0", + f"/model/constants/FLOAT/{self.layernorm_attrs['epsilon']}", + ] self.make_add(make_add_name, make_add_inputs, ir.DataType.FLOAT, shape=shape) make_sqrt_name = f"{basename}/Sqrt" @@ -1829,7 +2371,6 @@ def _make_simplified_layer_norm(self, basename, root_input, weight_name, output_ self.make_node("Mul", inputs=make_mul_1_inputs, outputs=[output_0], name=make_mul_1_name) self.make_value(output_0, dtype=io_dtype, shape=shape) - def make_qk_norm(self, layer_id, attention): # Make subgraph to compute SimplifiedLayerNorm after Q and K MatMuls in attention: # @@ -1842,16 +2383,32 @@ def make_qk_norm(self, layer_id, attention): # Reshape (BxSxD) # Save kwargs shared by LayerNorm ops and precision types to use - layernorm_kwargs = {"epsilon": self.layernorm_attrs["epsilon"], "axis": -1, "stash_type": 1} + layernorm_kwargs = { + "epsilon": self.layernorm_attrs["epsilon"], + "axis": -1, + "stash_type": 1, + } old_io_dtype = self.io_dtype new_io_dtype = ir.DataType.FLOAT if self.layernorm_attrs["cast"]["use_fp32"] else self.io_dtype cast = old_io_dtype != new_io_dtype # Reshape Q MatMul from BxSxD to Bx(SxN)xH before LayerNorm q_reshape_1_name = f"/model/layers.{layer_id}/attn/q_norm/Reshape_1" - q_reshape_1_inputs = [self.attention_attrs["q_path"], f"/model/constants/INT64/[0, -1, {self.head_size}]"] + q_reshape_1_inputs = [ + self.attention_attrs["q_path"], + f"/model/constants/INT64/[0, -1, {self.head_size}]", + ] q_reshape_1_output = f"{q_reshape_1_name}/output_0" - self.make_reshape(q_reshape_1_name, q_reshape_1_inputs, dtype=self.io_dtype, shape=['batch_size', 'sequence_length * num_attention_heads', self.head_size]) + self.make_reshape( + q_reshape_1_name, + q_reshape_1_inputs, + dtype=self.io_dtype, + shape=[ + "batch_size", + "sequence_length * num_attention_heads", + self.head_size, + ], + ) # Make Q LayerNorm q_layernorm_name = f"/model/layers.{layer_id}/attn/q_norm/SimplifiedLayerNormalization" @@ -1860,28 +2417,72 @@ def make_qk_norm(self, layer_id, attention): self.make_initializer( attention.q_norm.weight + self.layernorm_attrs["add_offset"], q_weight_name, - to=new_io_dtype + to=new_io_dtype, ) # Create Cast nodes for inputs and outputs if old_dtype != new_dtype q_layernorm_inputs = [q_reshape_1_output, q_weight_name] q_layernorm_outputs = [q_layernorm_output] if cast: - q_layernorm_inputs, q_layernorm_outputs = self.make_layernorm_casts(q_layernorm_name, q_layernorm_inputs, q_layernorm_outputs, old_io_dtype, new_io_dtype) + q_layernorm_inputs, q_layernorm_outputs = self.make_layernorm_casts( + q_layernorm_name, + q_layernorm_inputs, + q_layernorm_outputs, + old_io_dtype, + new_io_dtype, + ) - self.make_node("SimplifiedLayerNormalization", inputs=q_layernorm_inputs, outputs=q_layernorm_outputs, name=q_layernorm_name, **layernorm_kwargs) - self.make_value(q_layernorm_outputs[0], dtype=new_io_dtype, shape=['batch_size', 'sequence_length * num_attention_heads', self.head_size]) + self.make_node( + "SimplifiedLayerNormalization", + inputs=q_layernorm_inputs, + outputs=q_layernorm_outputs, + name=q_layernorm_name, + **layernorm_kwargs, + ) + self.make_value( + q_layernorm_outputs[0], + dtype=new_io_dtype, + shape=[ + "batch_size", + "sequence_length * num_attention_heads", + self.head_size, + ], + ) # Reshape Q path after LayerNorm from Bx(SxN)xH to BxSxD q_reshape_2_name = f"/model/layers.{layer_id}/attn/q_norm/Reshape_2" - q_reshape_2_inputs = [q_layernorm_output, f"/model/constants/INT64/[0, -1, {self.num_attn_heads * self.head_size}]"] - self.make_reshape(q_reshape_2_name, q_reshape_2_inputs, dtype=self.io_dtype, shape=['batch_size', 'sequence_length', self.num_attn_heads * self.head_size]) + q_reshape_2_inputs = [ + q_layernorm_output, + f"/model/constants/INT64/[0, -1, {self.num_attn_heads * self.head_size}]", + ] + self.make_reshape( + q_reshape_2_name, + q_reshape_2_inputs, + dtype=self.io_dtype, + shape=[ + "batch_size", + "sequence_length", + self.num_attn_heads * self.head_size, + ], + ) # Reshape K MatMul from BxSxD to Bx(SxN)xH before LayerNorm k_reshape_1_name = f"/model/layers.{layer_id}/attn/k_norm/Reshape_1" - k_reshape_1_inputs = [self.attention_attrs["k_path"], f"/model/constants/INT64/[0, -1, {self.head_size}]"] + k_reshape_1_inputs = [ + self.attention_attrs["k_path"], + f"/model/constants/INT64/[0, -1, {self.head_size}]", + ] k_reshape_1_output = f"{k_reshape_1_name}/output_0" - self.make_reshape(k_reshape_1_name, k_reshape_1_inputs, dtype=self.io_dtype, shape=['batch_size', 'sequence_length * num_key_value_heads', self.head_size]) + self.make_reshape( + k_reshape_1_name, + k_reshape_1_inputs, + dtype=self.io_dtype, + shape=[ + "batch_size", + "sequence_length * num_key_value_heads", + self.head_size, + ], + ) # Make K LayerNorm k_layernorm_name = f"/model/layers.{layer_id}/attn/k_norm/SimplifiedLayerNormalization" @@ -1890,22 +2491,50 @@ def make_qk_norm(self, layer_id, attention): self.make_initializer( attention.k_norm.weight + self.layernorm_attrs["add_offset"], k_weight_name, - to=new_io_dtype + to=new_io_dtype, ) # Create Cast nodes for inputs and outputs if old_dtype != new_dtype k_layernorm_inputs = [k_reshape_1_output, k_weight_name] k_layernorm_outputs = [k_layernorm_output] if cast: - k_layernorm_inputs, k_layernorm_outputs = self.make_layernorm_casts(k_layernorm_name, k_layernorm_inputs, k_layernorm_outputs, old_io_dtype, new_io_dtype) + k_layernorm_inputs, k_layernorm_outputs = self.make_layernorm_casts( + k_layernorm_name, + k_layernorm_inputs, + k_layernorm_outputs, + old_io_dtype, + new_io_dtype, + ) - self.make_node("SimplifiedLayerNormalization", inputs=k_layernorm_inputs, outputs=k_layernorm_outputs, name=k_layernorm_name, **layernorm_kwargs) - self.make_value(k_layernorm_outputs[0], dtype=new_io_dtype, shape=['batch_size', 'sequence_length * num_key_value_heads', self.head_size]) + self.make_node( + "SimplifiedLayerNormalization", + inputs=k_layernorm_inputs, + outputs=k_layernorm_outputs, + name=k_layernorm_name, + **layernorm_kwargs, + ) + self.make_value( + k_layernorm_outputs[0], + dtype=new_io_dtype, + shape=[ + "batch_size", + "sequence_length * num_key_value_heads", + self.head_size, + ], + ) # Reshape K path after LayerNorm from Bx(SxN)xH to BxSxD k_reshape_2_name = f"/model/layers.{layer_id}/attn/k_norm/Reshape_2" - k_reshape_2_inputs = [k_layernorm_output, f"/model/constants/INT64/[0, -1, {self.num_kv_heads * self.head_size}]"] - self.make_reshape(k_reshape_2_name, k_reshape_2_inputs, dtype=self.io_dtype, shape=['batch_size', 'sequence_length', self.num_kv_heads * self.head_size]) + k_reshape_2_inputs = [ + k_layernorm_output, + f"/model/constants/INT64/[0, -1, {self.num_kv_heads * self.head_size}]", + ] + self.make_reshape( + k_reshape_2_name, + k_reshape_2_inputs, + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.num_kv_heads * self.head_size], + ) # Update q_path and k_path now self.attention_attrs["q_path"] = f"{q_reshape_2_name}/output_0" @@ -1994,14 +2623,34 @@ def make_repeat_kv(self, layer_id, root_input, past_kv, present_kv, **kwargs): # | | | # present_kv +------> Gather --> Unsqueeze -----+ reshape_1_name = f"{basename}/Reshape_1" - reshape_1_inputs = [root_input, f"/model/constants/INT64/[0, 0, {self.num_kv_heads}, -1]"] - self.make_reshape(reshape_1_name, reshape_1_inputs, dtype=self.io_dtype, shape=['batch_size', 'sequence_length', self.num_kv_heads, self.head_size]) + reshape_1_inputs = [ + root_input, + f"/model/constants/INT64/[0, 0, {self.num_kv_heads}, -1]", + ] + self.make_reshape( + reshape_1_name, + reshape_1_inputs, + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.num_kv_heads, self.head_size], + ) transpose_1_name = f"{basename}/Transpose_1" transpose_1_input = f"{reshape_1_name}/output_0" - self.make_transpose(transpose_1_name, transpose_1_input, dtype=self.io_dtype, shape=['batch_size', self.num_kv_heads, 'sequence_length', self.head_size], perm=[0,2,1,3]) + self.make_transpose( + transpose_1_name, + transpose_1_input, + dtype=self.io_dtype, + shape=["batch_size", self.num_kv_heads, "sequence_length", self.head_size], + perm=[0, 2, 1, 3], + ) concat_1_name = f"{basename}/Concat_1" concat_1_inputs = [past_kv, f"{transpose_1_name}/output_0"] - self.make_node("Concat", inputs=concat_1_inputs, outputs=[present_kv], name=concat_1_name, axis=2) + self.make_node( + "Concat", + inputs=concat_1_inputs, + outputs=[present_kv], + name=concat_1_name, + axis=2, + ) shape_1_name = f"{basename}/Shape_1" self.make_shape(shape_1_name, present_kv, shape=[4]) @@ -2030,14 +2679,28 @@ def make_repeat_kv(self, layer_id, root_input, past_kv, present_kv, **kwargs): unsqueeze_4_inputs = [f"{gather_4_name}/output_0", "/model/constants/INT64/[0]"] self.make_unsqueeze(unsqueeze_4_name, unsqueeze_4_inputs, dtype=ir.DataType.INT64, shape=[1]) concat_2_name = f"{basename}/Concat_2" - concat_2_inputs = [f"{unsqueeze_1_name}/output_0", f"{unsqueeze_2_name}/output_0", f"/model/constants/INT64/[{self.num_attn_heads // self.num_kv_heads}]", f"{unsqueeze_3_name}/output_0", f"{unsqueeze_4_name}/output_0"] + concat_2_inputs = [ + f"{unsqueeze_1_name}/output_0", + f"{unsqueeze_2_name}/output_0", + f"/model/constants/INT64/[{self.num_attn_heads // self.num_kv_heads}]", + f"{unsqueeze_3_name}/output_0", + f"{unsqueeze_4_name}/output_0", + ] self.make_concat(concat_2_name, concat_2_inputs, dtype=ir.DataType.INT64, shape=[5], axis=0) mul_1_name = f"{basename}/Mul_1" - mul_1_inputs = [f"{unsqueeze_2_name}/output_0", f"/model/constants/INT64/{self.num_attn_heads // self.num_kv_heads}"] + mul_1_inputs = [ + f"{unsqueeze_2_name}/output_0", + f"/model/constants/INT64/{self.num_attn_heads // self.num_kv_heads}", + ] self.make_mul(mul_1_name, mul_1_inputs, dtype=ir.DataType.INT64, shape=None) concat_3_name = f"{basename}/Concat_3" - concat_3_inputs = [f"{unsqueeze_1_name}/output_0", f"{mul_1_name}/output_0", f"{unsqueeze_3_name}/output_0", f"{unsqueeze_4_name}/output_0"] + concat_3_inputs = [ + f"{unsqueeze_1_name}/output_0", + f"{mul_1_name}/output_0", + f"{unsqueeze_3_name}/output_0", + f"{unsqueeze_4_name}/output_0", + ] self.make_concat(concat_3_name, concat_3_inputs, dtype=ir.DataType.INT64, shape=[4], axis=0) # Make the subgraph that follows the initial subgraph @@ -2054,7 +2717,13 @@ def make_repeat_kv(self, layer_id, root_input, past_kv, present_kv, **kwargs): self.make_shape(shape_2_name, f"{reshape_2_name}/output_0", shape=[1]) constant_shape_name = f"{basename}/ConstantOfShape" constant_shape_value = ir.tensor([1], dtype=ir.DataType.INT64) - self.make_constant_of_shape(constant_shape_name, f"{shape_2_name}/output_0", value=constant_shape_value, dtype=ir.DataType.INT64, shape=[5]) + self.make_constant_of_shape( + constant_shape_name, + f"{shape_2_name}/output_0", + value=constant_shape_value, + dtype=ir.DataType.INT64, + shape=[5], + ) mul_2_name = f"{basename}/Mul" mul_2_inputs = [f"{constant_shape_name}/output_0", "/model/constants/INT64/-1"] self.make_mul(mul_2_name, mul_2_inputs, dtype=ir.DataType.INT64, shape=[5]) @@ -2062,7 +2731,11 @@ def make_repeat_kv(self, layer_id, root_input, past_kv, present_kv, **kwargs): equal_inputs = [f"{reshape_2_name}/output_0", f"{mul_2_name}/output_0"] self.make_equal(equal_name, equal_inputs, shape=[5]) where_name = f"{basename}/Where" - where_inputs = [f"{equal_name}/output_0", f"{constant_shape_name}/output_0", f"{reshape_2_name}/output_0"] + where_inputs = [ + f"{equal_name}/output_0", + f"{constant_shape_name}/output_0", + f"{reshape_2_name}/output_0", + ] self.make_where(where_name, where_inputs, dtype=ir.DataType.INT64, shape=[5]) # Make the final nodes @@ -2072,19 +2745,74 @@ def make_repeat_kv(self, layer_id, root_input, past_kv, present_kv, **kwargs): # Unsqueeze --> Expand --> Reshape --> Transpose --> Reshape unsqueeze_5_name = f"{basename}/Unsqueeze_5" unsqueeze_5_inputs = [present_kv, "/model/constants/INT64/[2]"] - self.make_unsqueeze(unsqueeze_5_name, unsqueeze_5_inputs, dtype=self.io_dtype, shape=['batch_size', self.num_kv_heads, 1, 'sequence_length', self.head_size]) + self.make_unsqueeze( + unsqueeze_5_name, + unsqueeze_5_inputs, + dtype=self.io_dtype, + shape=[ + "batch_size", + self.num_kv_heads, + 1, + "sequence_length", + self.head_size, + ], + ) expand_name = f"{basename}/Expand" expand_inputs = [f"{unsqueeze_5_name}/output_0", f"{where_name}/output_0"] - self.make_expand(expand_name, expand_inputs, dtype=self.io_dtype, shape=['batch_size', self.num_kv_heads, self.num_attn_heads // self.num_kv_heads, 'sequence_length', self.head_size]) + self.make_expand( + expand_name, + expand_inputs, + dtype=self.io_dtype, + shape=[ + "batch_size", + self.num_kv_heads, + self.num_attn_heads // self.num_kv_heads, + "sequence_length", + self.head_size, + ], + ) reshape_3_name = f"{basename}/Reshape_3" reshape_3_inputs = [f"{expand_name}/output_0", f"{concat_3_name}/output_0"] - self.make_reshape(reshape_3_name, reshape_3_inputs, dtype=self.io_dtype, shape=['batch_size', self.num_attn_heads, 'sequence_length', self.head_size]) + self.make_reshape( + reshape_3_name, + reshape_3_inputs, + dtype=self.io_dtype, + shape=[ + "batch_size", + self.num_attn_heads, + "sequence_length", + self.head_size, + ], + ) transpose_2_name = f"{basename}/Transpose_2" transpose_2_input = f"{reshape_3_name}/output_0" - self.make_transpose(transpose_2_name, transpose_2_input, dtype=self.io_dtype, shape=['batch_size', 'sequence_length', self.num_attn_heads, self.head_size], perm=[0,2,1,3]) + self.make_transpose( + transpose_2_name, + transpose_2_input, + dtype=self.io_dtype, + shape=[ + "batch_size", + "sequence_length", + self.num_attn_heads, + self.head_size, + ], + perm=[0, 2, 1, 3], + ) reshape_4_name = f"{basename}/Reshape_4" - reshape_4_inputs = [f"{transpose_2_name}/output_0", f"/model/constants/INT64/[0, 0, {self.num_attn_heads * self.head_size}]"] - self.make_reshape(reshape_4_name, reshape_4_inputs, dtype=self.io_dtype, shape=['batch_size', 'sequence_length', self.num_attn_heads * self.head_size]) + reshape_4_inputs = [ + f"{transpose_2_name}/output_0", + f"/model/constants/INT64/[0, 0, {self.num_attn_heads * self.head_size}]", + ] + self.make_reshape( + reshape_4_name, + reshape_4_inputs, + dtype=self.io_dtype, + shape=[ + "batch_size", + "sequence_length", + self.num_attn_heads * self.head_size, + ], + ) input_to_attention = f"{reshape_4_name}/output_0" return input_to_attention @@ -2095,33 +2823,69 @@ def make_attention_op(self, name, **kwargs): if op_type == "MultiHeadAttention": self.make_multi_head_attention(name, add_qk=f"{self.mask_attrs['mask_name']}/output_0", **kwargs) elif op_type == "GroupQueryAttention": - self.make_group_query_attention(name, seqlens_k=f"{self.mask_attrs['seqlens_k']}/output_0", total_seq_len=f"{self.mask_attrs['total_seq_len']}/output_0", **kwargs) + self.make_group_query_attention( + name, + seqlens_k=f"{self.mask_attrs['seqlens_k']}/output_0", + total_seq_len=f"{self.mask_attrs['total_seq_len']}/output_0", + **kwargs, + ) elif op_type == "SparseAttention": - self.make_sparse_attention(name, block_row_indices=self.mask_attrs['block_row_indices'], block_col_indices=self.mask_attrs['block_col_indices'], key_total_seq_lens=f"{self.mask_attrs['key_total_seq_lens']}/output_0", total_seq_len=f"{self.mask_attrs['total_seq_len']}/output_0", **kwargs) + self.make_sparse_attention( + name, + block_row_indices=self.mask_attrs["block_row_indices"], + block_col_indices=self.mask_attrs["block_col_indices"], + key_total_seq_lens=f"{self.mask_attrs['key_total_seq_lens']}/output_0", + total_seq_len=f"{self.mask_attrs['total_seq_len']}/output_0", + **kwargs, + ) else: raise NotImplementedError(f"The {op_type} op is not currently supported.") def make_multi_head_attention(self, name, **kwargs): inputs = [ - kwargs["q_path"], kwargs["k_path"], kwargs["v_path"], kwargs.get("bias", ""), - kwargs.get("attn_mask", ""), kwargs.get("add_qk", ""), - kwargs.get("past_k", ""), kwargs.get("past_v", ""), + kwargs["q_path"], + kwargs["k_path"], + kwargs["v_path"], + kwargs.get("bias", ""), + kwargs.get("attn_mask", ""), + kwargs.get("add_qk", ""), + kwargs.get("past_k", ""), + kwargs.get("past_v", ""), ] output = f"{name}/output_0" outputs = [output, kwargs.get("present_k", ""), kwargs.get("present_v", "")] self.make_node( - "MultiHeadAttention", inputs=inputs, outputs=outputs, name=name, domain="com.microsoft", - num_heads=self.num_attn_heads, scale=self.attention_attrs["scale"], + "MultiHeadAttention", + inputs=inputs, + outputs=outputs, + name=name, + domain="com.microsoft", + num_heads=self.num_attn_heads, + scale=self.attention_attrs["scale"], + ) + self.make_value( + output, + self.io_dtype, + shape=[ + "batch_size", + "sequence_length", + self.head_size * self.num_attn_heads, + ], ) - self.make_value(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.head_size * self.num_attn_heads]) def make_group_query_attention(self, name, **kwargs): inputs = [ - kwargs["q_path"], kwargs["k_path"], kwargs["v_path"], - kwargs.get("past_k", ""), kwargs.get("past_v", ""), - kwargs.get("seqlens_k", ""), kwargs.get("total_seq_len", ""), - kwargs.get("cos_cache", ""), kwargs.get("sin_cache", ""), - "", "", # position_ids, attention_bias + kwargs["q_path"], + kwargs["k_path"], + kwargs["v_path"], + kwargs.get("past_k", ""), + kwargs.get("past_v", ""), + kwargs.get("seqlens_k", ""), + kwargs.get("total_seq_len", ""), + kwargs.get("cos_cache", ""), + kwargs.get("sin_cache", ""), + "", + "", # position_ids, attention_bias ] sinks = kwargs.get("sinks", "") # TODO: add to inputs list directly once ORT 1.23 is out (one-time exception) if sinks: @@ -2130,26 +2894,57 @@ def make_group_query_attention(self, name, **kwargs): output = f"{name}/output_0" outputs = [output, kwargs.get("present_k", ""), kwargs.get("present_v", "")] self.make_node( - "GroupQueryAttention", inputs=inputs, outputs=outputs, name=name, domain="com.microsoft", - num_heads=self.num_attn_heads, kv_num_heads=self.num_kv_heads, scale=self.attention_attrs["scale"], local_window_size=self.window_size, - softcap=self.attention_attrs["softcap"], do_rotary=self.attention_attrs["use_rope_in_attn"], rotary_interleaved=self.rope_attrs["interleaved"], + "GroupQueryAttention", + inputs=inputs, + outputs=outputs, + name=name, + domain="com.microsoft", + num_heads=self.num_attn_heads, + kv_num_heads=self.num_kv_heads, + scale=self.attention_attrs["scale"], + local_window_size=self.window_size, + softcap=self.attention_attrs["softcap"], + do_rotary=self.attention_attrs["use_rope_in_attn"], + rotary_interleaved=self.rope_attrs["interleaved"], + ) + self.make_value( + output, + self.io_dtype, + shape=[ + "batch_size", + "sequence_length", + self.head_size * self.num_attn_heads, + ], ) - self.make_value(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.head_size * self.num_attn_heads]) def make_sparse_attention(self, name, **kwargs): inputs = [ - kwargs["q_path"], kwargs["k_path"], kwargs["v_path"], - kwargs.get("past_k"), kwargs.get("past_v"), - kwargs.get("block_row_indices"), kwargs.get("block_col_indices"), - kwargs.get("total_seq_len"), kwargs.get("key_total_seq_lens"), - kwargs.get("cos_cache", ""), kwargs.get("sin_cache", ""), + kwargs["q_path"], + kwargs["k_path"], + kwargs["v_path"], + kwargs.get("past_k"), + kwargs.get("past_v"), + kwargs.get("block_row_indices"), + kwargs.get("block_col_indices"), + kwargs.get("total_seq_len"), + kwargs.get("key_total_seq_lens"), + kwargs.get("cos_cache", ""), + kwargs.get("sin_cache", ""), ] output = f"{name}/output_0" outputs = [output, kwargs.get("present_k", ""), kwargs.get("present_v", "")] self.make_node( - "SparseAttention", inputs=inputs, outputs=outputs, name=name, domain="com.microsoft", - num_heads=self.num_attn_heads, kv_num_heads=self.num_kv_heads, scale=self.attention_attrs["scale"], sparse_block_size=self.attention_attrs["block_sparse"]["sparse_block_size"], - do_rotary=self.attention_attrs["use_rope_in_attn"], rotary_interleaved=self.rope_attrs["interleaved"], + "SparseAttention", + inputs=inputs, + outputs=outputs, + name=name, + domain="com.microsoft", + num_heads=self.num_attn_heads, + kv_num_heads=self.num_kv_heads, + scale=self.attention_attrs["scale"], + sparse_block_size=self.attention_attrs["block_sparse"]["sparse_block_size"], + do_rotary=self.attention_attrs["use_rope_in_attn"], + rotary_interleaved=self.rope_attrs["interleaved"], ) def make_attention(self, layer_id, attention, root_input, **kwargs): @@ -2189,18 +2984,28 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): # Unpack attention weights if needed self.make_attention_unpacked(layer_id, attention, root_input, **kwargs) - + # Get dtype used for MatMul ops q_dtype = getattr(attention.q_proj, "weight", getattr(attention.q_proj, "bits", None)) k_dtype = getattr(attention.k_proj, "weight", getattr(attention.k_proj, "bits", None)) v_dtype = getattr(attention.v_proj, "weight", getattr(attention.v_proj, "bits", None)) - qkv_dtype_equal = getattr(q_dtype, "dtype", q_dtype) == getattr(k_dtype, "dtype", k_dtype) == getattr(v_dtype, "dtype", v_dtype) + qkv_dtype_equal = ( + getattr(q_dtype, "dtype", q_dtype) + == getattr(k_dtype, "dtype", k_dtype) + == getattr(v_dtype, "dtype", v_dtype) + ) # Make MatMul nodes if self.attention_attrs["use_packed_matmul"] and qkv_dtype_equal: # Combine 3 MatMuls into 1 packed MatMul qkv_matmul_basename = f"/model/layers.{layer_id}/attn/qkv_proj/MatMul" - qkv_matmul_name = self.make_packed_matmul(attention.q_proj, attention.k_proj, attention.v_proj, qkv_matmul_basename, root_input) + qkv_matmul_name = self.make_packed_matmul( + attention.q_proj, + attention.k_proj, + attention.v_proj, + qkv_matmul_basename, + root_input, + ) self.attention_attrs["q_path"] = f"{qkv_matmul_name}/output_0" else: q_matmul_basename = f"/model/layers.{layer_id}/attn/q_proj/MatMul" @@ -2222,20 +3027,38 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): if self.attention_attrs["use_packed_matmul"] and qkv_dtype_equal and any_bias_exists: # Combine 3 Adds into 1 packed Add qkv_add_name = f"/model/layers.{layer_id}/attn/qkv_proj/Add" - self.make_packed_add(attention.q_proj.bias, attention.k_proj.bias, attention.v_proj.bias, qkv_add_name, root_input=self.attention_attrs["q_path"]) + self.make_packed_add( + attention.q_proj.bias, + attention.k_proj.bias, + attention.v_proj.bias, + qkv_add_name, + root_input=self.attention_attrs["q_path"], + ) self.attention_attrs["q_path"] = f"{qkv_add_name}/output_0" else: if q_bias_exists: q_add_name = f"/model/layers.{layer_id}/attn/q_proj/Add" - self.make_add_bias(attention.q_proj.bias, q_add_name, root_input=self.attention_attrs["q_path"]) + self.make_add_bias( + attention.q_proj.bias, + q_add_name, + root_input=self.attention_attrs["q_path"], + ) self.attention_attrs["q_path"] = f"{q_add_name}/output_0" if k_bias_exists: k_add_name = f"/model/layers.{layer_id}/attn/k_proj/Add" - self.make_add_bias(attention.k_proj.bias, k_add_name, root_input=self.attention_attrs["k_path"]) + self.make_add_bias( + attention.k_proj.bias, + k_add_name, + root_input=self.attention_attrs["k_path"], + ) self.attention_attrs["k_path"] = f"{k_add_name}/output_0" if v_bias_exists: v_add_name = f"/model/layers.{layer_id}/attn/v_proj/Add" - self.make_add_bias(attention.v_proj.bias, v_add_name, root_input=self.attention_attrs["v_path"]) + self.make_add_bias( + attention.v_proj.bias, + v_add_name, + root_input=self.attention_attrs["v_path"], + ) self.attention_attrs["v_path"] = f"{v_add_name}/output_0" # Make Q/K SimplifiedLayerNorm nodes @@ -2248,10 +3071,18 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): cos_cache_name, sin_cache_name = self.make_rotary_embedding_caches() else: q_rotary_name = f"/model/layers.{layer_id}/attn/q_rotary/RotaryEmbedding" - self.make_rotary_embedding(q_rotary_name, root_input=self.attention_attrs["q_path"], position_ids=kwargs.get("position_ids", "position_ids")) + self.make_rotary_embedding( + q_rotary_name, + root_input=self.attention_attrs["q_path"], + position_ids=kwargs.get("position_ids", "position_ids"), + ) self.attention_attrs["q_path"] = f"{q_rotary_name}/output_0" k_rotary_name = f"/model/layers.{layer_id}/attn/k_rotary/RotaryEmbedding" - self.make_rotary_embedding(k_rotary_name, root_input=self.attention_attrs["k_path"], position_ids=kwargs.get("position_ids", "position_ids")) + self.make_rotary_embedding( + k_rotary_name, + root_input=self.attention_attrs["k_path"], + position_ids=kwargs.get("position_ids", "position_ids"), + ) self.attention_attrs["k_path"] = f"{k_rotary_name}/output_0" # Make repeat KV nodes (Note: `repeat_kv` needs to be kept since GroupQueryAttention isn't supported for FP32 CUDA) @@ -2260,8 +3091,18 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): present_k = f"present.{layer_id}.key" present_v = f"present.{layer_id}.value" if self.num_attn_heads != self.num_kv_heads and self.attention_attrs["op_type"] == "MultiHeadAttention": - self.attention_attrs["k_path"] = self.make_repeat_kv(layer_id, root_input=self.attention_attrs["k_path"], past_kv=past_k, present_kv=present_k) - self.attention_attrs["v_path"] = self.make_repeat_kv(layer_id, root_input=self.attention_attrs["v_path"], past_kv=past_v, present_kv=present_v) + self.attention_attrs["k_path"] = self.make_repeat_kv( + layer_id, + root_input=self.attention_attrs["k_path"], + past_kv=past_k, + present_kv=present_k, + ) + self.attention_attrs["v_path"] = self.make_repeat_kv( + layer_id, + root_input=self.attention_attrs["v_path"], + past_kv=past_v, + present_kv=present_v, + ) past_k, past_v, present_k, present_v = "", "", "", "" # Make sinks input @@ -2273,13 +3114,22 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): # Make attention node (e.g. MultiHeadAttention, GroupQueryAttention, etc.) attn_name = f"/model/layers.{layer_id}/attn/{self.attention_attrs['op_type']}" self.make_attention_op( - attn_name, q_path=self.attention_attrs["q_path"], k_path=self.attention_attrs["k_path"], v_path=self.attention_attrs["v_path"], - past_k=past_k, past_v=past_v, present_k=present_k, present_v=present_v, - cos_cache=cos_cache_name, sin_cache=sin_cache_name, sinks=sinks_name, **kwargs, + attn_name, + q_path=self.attention_attrs["q_path"], + k_path=self.attention_attrs["k_path"], + v_path=self.attention_attrs["v_path"], + past_k=past_k, + past_v=past_v, + present_k=present_k, + present_v=present_v, + cos_cache=cos_cache_name, + sin_cache=sin_cache_name, + sinks=sinks_name, + **kwargs, ) # Make MatMul node (output projection weight node) - o_proj = 'o_proj' if hasattr(attention, 'o_proj') else 'dense' + o_proj = "o_proj" if hasattr(attention, "o_proj") else "dense" o_matmul_basename = f"/model/layers.{layer_id}/attn/o_proj/MatMul" o_weight = getattr(attention, o_proj) o_matmul_name = self.make_matmul(o_weight, o_matmul_basename, f"{attn_name}/output_0") @@ -2318,31 +3168,47 @@ def make_attention_unpacked_lora(self, layer_id, attention, qkv_linear, root_inp # Create Q/K/V base layers q_proj = torch.nn.Linear(in_features=q_size, out_features=q_size) - q_proj.weight = torch.nn.Parameter(qkv_linear.weight[: q_size, :], requires_grad=False) - q_proj.bias = None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[: q_size], requires_grad=False) + q_proj.weight = torch.nn.Parameter(qkv_linear.weight[:q_size, :], requires_grad=False) + q_proj.bias = ( + None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[:q_size], requires_grad=False) + ) k_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) k_proj.weight = torch.nn.Parameter(qkv_linear.weight[q_size : q_size + kv_size, :], requires_grad=False) - k_proj.bias = None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[q_size : q_size + kv_size], requires_grad=False) + k_proj.bias = ( + None + if qkv_linear.bias is None + else torch.nn.Parameter(qkv_linear.bias[q_size : q_size + kv_size], requires_grad=False) + ) v_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) v_proj.weight = torch.nn.Parameter(qkv_linear.weight[q_size + kv_size :, :], requires_grad=False) - v_proj.bias = None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[q_size + kv_size :], requires_grad=False) + v_proj.bias = ( + None + if qkv_linear.bias is None + else torch.nn.Parameter(qkv_linear.bias[q_size + kv_size :], requires_grad=False) + ) # Create Q/K/V lora_B layers lora_B = qkv_linear.lora_B.default q_lora_B = torch.nn.Linear(in_features=q_size, out_features=q_size) - q_lora_B.weight = torch.nn.Parameter(lora_B.weight[: q_size, :], requires_grad=False) - q_lora_B.bias = None if lora_B.bias is None else torch.nn.Parameter(lora_B.bias[: q_size], requires_grad=False) + q_lora_B.weight = torch.nn.Parameter(lora_B.weight[:q_size, :], requires_grad=False) + q_lora_B.bias = None if lora_B.bias is None else torch.nn.Parameter(lora_B.bias[:q_size], requires_grad=False) k_lora_B = torch.nn.Linear(in_features=q_size, out_features=kv_size) k_lora_B.weight = torch.nn.Parameter(lora_B.weight[q_size : q_size + kv_size, :], requires_grad=False) - k_lora_B.bias = None if lora_B.bias is None else torch.nn.Parameter(lora_B.bias[q_size : q_size + kv_size], requires_grad=False) + k_lora_B.bias = ( + None + if lora_B.bias is None + else torch.nn.Parameter(lora_B.bias[q_size : q_size + kv_size], requires_grad=False) + ) v_lora_B = torch.nn.Linear(in_features=q_size, out_features=kv_size) v_lora_B.weight = torch.nn.Parameter(lora_B.weight[q_size + kv_size :, :], requires_grad=False) - v_lora_B.bias = None if lora_B.bias is None else torch.nn.Parameter(lora_B.bias[q_size + kv_size :], requires_grad=False) + v_lora_B.bias = ( + None if lora_B.bias is None else torch.nn.Parameter(lora_B.bias[q_size + kv_size :], requires_grad=False) + ) # Create Q/K/V LoRA layers attention.q_proj = LoraLayer(q_proj) @@ -2365,16 +3231,28 @@ def make_attention_unpacked_regular(self, layer_id, attention, qkv_linear, root_ kv_size = self.num_kv_heads * self.head_size attention.q_proj = torch.nn.Linear(in_features=q_size, out_features=q_size) - attention.q_proj.weight = torch.nn.Parameter(qkv_linear.weight[: q_size, :], requires_grad=False) - attention.q_proj.bias = None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[: q_size], requires_grad=False) + attention.q_proj.weight = torch.nn.Parameter(qkv_linear.weight[:q_size, :], requires_grad=False) + attention.q_proj.bias = ( + None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[:q_size], requires_grad=False) + ) attention.k_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) - attention.k_proj.weight = torch.nn.Parameter(qkv_linear.weight[q_size : q_size + kv_size, :], requires_grad=False) - attention.k_proj.bias = None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[q_size : q_size + kv_size], requires_grad=False) + attention.k_proj.weight = torch.nn.Parameter( + qkv_linear.weight[q_size : q_size + kv_size, :], requires_grad=False + ) + attention.k_proj.bias = ( + None + if qkv_linear.bias is None + else torch.nn.Parameter(qkv_linear.bias[q_size : q_size + kv_size], requires_grad=False) + ) attention.v_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) attention.v_proj.weight = torch.nn.Parameter(qkv_linear.weight[q_size + kv_size :, :], requires_grad=False) - attention.v_proj.bias = None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[q_size + kv_size :], requires_grad=False) + attention.v_proj.bias = ( + None + if qkv_linear.bias is None + else torch.nn.Parameter(qkv_linear.bias[q_size + kv_size :], requires_grad=False) + ) def make_mlp(self, layer_id, mlp, root_input): # Unpack MLP weights if needed @@ -2385,7 +3263,7 @@ def make_mlp(self, layer_id, mlp, root_input): elif self.mlp_attrs["use_fc"]: self.make_mlp_fc(layer_id, mlp, root_input) else: - raise NotImplementedError(f"The MLP layer type is not set.") + raise NotImplementedError("The MLP layer type is not set.") def make_mlp_unpacked(self, layer_id, mlp, root_input): gate_up_linear = getattr(mlp, "gate_up_proj", None) or getattr(mlp, "dense_h_to_4h", None) @@ -2408,23 +3286,39 @@ def make_mlp_unpacked_lora(self, layer_id, mlp, gate_up_linear, root_input): # Create GateProj/UpProj base layers gate_proj = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size) - gate_proj.weight = torch.nn.Parameter(gate_up_linear.weight[ : self.intermediate_size, :], requires_grad=False) - gate_proj.bias = None if gate_up_linear.bias is None else torch.nn.Parameter(gate_up_linear.bias[: self.intermediate_size], requires_grad=False) + gate_proj.weight = torch.nn.Parameter(gate_up_linear.weight[: self.intermediate_size, :], requires_grad=False) + gate_proj.bias = ( + None + if gate_up_linear.bias is None + else torch.nn.Parameter(gate_up_linear.bias[: self.intermediate_size], requires_grad=False) + ) up_proj = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size) up_proj.weight = torch.nn.Parameter(gate_up_linear.weight[self.intermediate_size :, :], requires_grad=False) - up_proj.bias = None if gate_up_linear.bias is None else torch.nn.Parameter(gate_up_linear.bias[self.intermediate_size :], requires_grad=False) + up_proj.bias = ( + None + if gate_up_linear.bias is None + else torch.nn.Parameter(gate_up_linear.bias[self.intermediate_size :], requires_grad=False) + ) # Create GateProj/UpProj lora_B layers lora_B = gate_up_linear.lora_B.default gate_proj_lora_B = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size) - gate_proj_lora_B.weight = torch.nn.Parameter(lora_B.weight[ : self.intermediate_size, :], requires_grad=False) - gate_proj_lora_B.bias = None if lora_B.bias is None else torch.nn.Parameter(lora_B.bias[: self.intermediate_size], requires_grad=False) + gate_proj_lora_B.weight = torch.nn.Parameter(lora_B.weight[: self.intermediate_size, :], requires_grad=False) + gate_proj_lora_B.bias = ( + None + if lora_B.bias is None + else torch.nn.Parameter(lora_B.bias[: self.intermediate_size], requires_grad=False) + ) up_proj_lora_B = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size) up_proj_lora_B.weight = torch.nn.Parameter(lora_B.weight[self.intermediate_size :, :], requires_grad=False) - up_proj_lora_B.bias = None if lora_B.bias is None else torch.nn.Parameter(lora_B.bias[self.intermediate_size :], requires_grad=False) + up_proj_lora_B.bias = ( + None + if lora_B.bias is None + else torch.nn.Parameter(lora_B.bias[self.intermediate_size :], requires_grad=False) + ) # Create GateProj/UpProj LoRA layers mlp.gate_proj = LoraLayer(gate_proj) @@ -2439,12 +3333,22 @@ def make_mlp_unpacked_lora(self, layer_id, mlp, gate_up_linear, root_input): def make_mlp_unpacked_regular(self, layer_id, mlp, gate_up_linear, root_input): mlp.gate_proj = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size) - mlp.gate_proj.weight = torch.nn.Parameter(gate_up_linear.weight[: self.intermediate_size, :], requires_grad=False) - mlp.gate_proj.bias = None if gate_up_linear.bias is None else torch.nn.Parameter(gate_up_linear.bias[: self.intermediate_size], requires_grad=False) + mlp.gate_proj.weight = torch.nn.Parameter( + gate_up_linear.weight[: self.intermediate_size, :], requires_grad=False + ) + mlp.gate_proj.bias = ( + None + if gate_up_linear.bias is None + else torch.nn.Parameter(gate_up_linear.bias[: self.intermediate_size], requires_grad=False) + ) mlp.up_proj = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size) mlp.up_proj.weight = torch.nn.Parameter(gate_up_linear.weight[self.intermediate_size :, :]) - mlp.up_proj.bias = None if gate_up_linear.bias is None else torch.nn.Parameter(gate_up_linear.bias[self.intermediate_size :], requires_grad=False) + mlp.up_proj.bias = ( + None + if gate_up_linear.bias is None + else torch.nn.Parameter(gate_up_linear.bias[self.intermediate_size :], requires_grad=False) + ) def make_mlp_proj(self, layer_id, mlp, root_input): # Make nodes for the MLP subgraph @@ -2495,7 +3399,12 @@ def make_mlp_proj(self, layer_id, mlp, root_input): # Make Mul node after activation mul_name = f"/model/layers.{layer_id}/mlp/Mul" mul_inputs = [f"{act_fn_name}/output_0", f"{up_name}/output_0"] - self.make_mul(mul_name, mul_inputs, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) + self.make_mul( + mul_name, + mul_inputs, + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.intermediate_size], + ) # Make output MatMul node down_matmul_basename = f"/model/layers.{layer_id}/mlp/down_proj/MatMul" @@ -2564,16 +3473,26 @@ def make_moe_op(self, name, **kwargs): def make_base_moe_op(self, name, **kwargs): inputs = [ - kwargs["root_input"], kwargs["router_probs"], - kwargs["weight1"], kwargs.get("bias1", ""), - kwargs["weight2"], kwargs.get("bias2", ""), - kwargs.get("weight3", ""), kwargs.get("bias3", ""), + kwargs["root_input"], + kwargs["router_probs"], + kwargs["weight1"], + kwargs.get("bias1", ""), + kwargs["weight2"], + kwargs.get("bias2", ""), + kwargs.get("weight3", ""), + kwargs.get("bias3", ""), ] output = f"{name}/output_0" - extra_kwargs = {"swiglu_limit": self.moe_attrs["swiglu_limit"]} if self.moe_attrs["swiglu_limit"] is not None else {} + extra_kwargs = ( + {"swiglu_limit": self.moe_attrs["swiglu_limit"]} if self.moe_attrs["swiglu_limit"] is not None else {} + ) self.make_node( - "MoE", inputs=inputs, outputs=[output], name=name, domain="com.microsoft", + "MoE", + inputs=inputs, + outputs=[output], + name=name, + domain="com.microsoft", activation_alpha=self.moe_attrs["activation_alpha"], activation_beta=self.moe_attrs["activation_beta"], activation_type=self.moe_attrs["activation_type"], @@ -2583,20 +3502,37 @@ def make_base_moe_op(self, name, **kwargs): use_sparse_mixer=self.moe_attrs["use_sparse_mixer"], **extra_kwargs, ) - self.make_value(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size]) + self.make_value( + output, + self.io_dtype, + shape=["batch_size", "sequence_length", self.hidden_size], + ) def make_qmoe_op(self, name, **kwargs): inputs = [ - kwargs["root_input"], kwargs["router_probs"], - kwargs["weight1"], kwargs["scales1"], kwargs.get("bias1", ""), - kwargs["weight2"], kwargs["scales2"], kwargs.get("bias2", ""), - kwargs.get("weight3", ""), kwargs.get("scales3", ""), kwargs.get("bias3", ""), + kwargs["root_input"], + kwargs["router_probs"], + kwargs["weight1"], + kwargs["scales1"], + kwargs.get("bias1", ""), + kwargs["weight2"], + kwargs["scales2"], + kwargs.get("bias2", ""), + kwargs.get("weight3", ""), + kwargs.get("scales3", ""), + kwargs.get("bias3", ""), ] output = f"{name}/output_0" - extra_kwargs = {"swiglu_limit": self.moe_attrs["swiglu_limit"]} if self.moe_attrs["swiglu_limit"] is not None else {} + extra_kwargs = ( + {"swiglu_limit": self.moe_attrs["swiglu_limit"]} if self.moe_attrs["swiglu_limit"] is not None else {} + ) self.make_node( - "QMoE", inputs=inputs, outputs=[output], name=name, domain="com.microsoft", + "QMoE", + inputs=inputs, + outputs=[output], + name=name, + domain="com.microsoft", activation_alpha=self.moe_attrs["activation_alpha"], activation_beta=self.moe_attrs["activation_beta"], activation_type=self.moe_attrs["activation_type"], @@ -2608,7 +3544,11 @@ def make_qmoe_op(self, name, **kwargs): block_size=self.moe_attrs["block_size"], **extra_kwargs, ) - self.make_value(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size]) + self.make_value( + output, + self.io_dtype, + shape=["batch_size", "sequence_length", self.hidden_size], + ) def make_qmoe_weights(self, weights): dtype = torch.quint4x2 if self.moe_attrs["expert_weight_bits"] == 4 else torch.int8 @@ -2634,19 +3574,25 @@ def make_qmoe_weights(self, weights): try: import tensorrt_llm - _, qweight, scales = ( - torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix(weights.detach().cpu().contiguous(), dtype) + _, qweight, scales = torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix( + weights.detach().cpu().contiguous(), dtype ) unsuccessful = False except ImportError: - print("WARNING: TensorRT-LLM is needed to use torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix().") + print( + "WARNING: TensorRT-LLM is needed to use torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix()." + ) except RuntimeError as r: - print("WARNING: TensorRT-LLM failed to run torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix() successfully.") + print( + "WARNING: TensorRT-LLM failed to run torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix() successfully." + ) err = str(r) - print(err[ : err.find('\n1')]) # omit internal traceback inside TensorRT-LLM + print(err[: err.find("\n1")]) # omit internal traceback inside TensorRT-LLM finally: if unsuccessful: - raise RuntimeError("Failed to quantize MoE weights with TensorRT-LLM. Please ensure TensorRT-LLM installs and runs successfully in your environment.") + raise RuntimeError( + "Failed to quantize MoE weights with TensorRT-LLM. Please ensure TensorRT-LLM installs and runs successfully in your environment." + ) return qweight, scales.to(torch.float16) @@ -2680,7 +3626,11 @@ def _symmetric_blockwise_quantize(self, weights, block_size): # Avoid division by zero - set minimum scale min_scale = 1e-8 - scales = torch.where(scales < min_scale, torch.tensor(min_scale, dtype=scales.dtype, device=scales.device), scales) + scales = torch.where( + scales < min_scale, + torch.tensor(min_scale, dtype=scales.dtype, device=scales.device), + scales, + ) # Expand scales for broadcasting: [..., num_blocks, 1] scales_expanded = scales.unsqueeze(-1) @@ -2754,13 +3704,35 @@ def make_block_sparse_moe(self, layer_id, bsm, root_input): shape_name = f"{gate_ops_base}/Shape" self.make_shape(shape_name, f"{gate_name}/output_0", shape=[3]) gather_name = f"{gate_ops_base}/Gather" - self.make_gather(gather_name, [f"{shape_name}/output_0", "/model/constants/INT64/2"], dtype=ir.DataType.INT64, shape=[], axis=0) + self.make_gather( + gather_name, + [f"{shape_name}/output_0", "/model/constants/INT64/2"], + dtype=ir.DataType.INT64, + shape=[], + axis=0, + ) unsqueeze_name = f"{gate_ops_base}/Unsqueeze" - self.make_unsqueeze(unsqueeze_name, [f"{gather_name}/output_0", "/model/constants/INT64/[0]"], dtype=ir.DataType.INT64, shape=[1]) + self.make_unsqueeze( + unsqueeze_name, + [f"{gather_name}/output_0", "/model/constants/INT64/[0]"], + dtype=ir.DataType.INT64, + shape=[1], + ) concat_name = f"{gate_ops_base}/Concat" - self.make_concat(concat_name, ["/model/constants/INT64/[-1]", f"{unsqueeze_name}/output_0"], dtype=ir.DataType.INT64, shape=[2], axis=0) + self.make_concat( + concat_name, + ["/model/constants/INT64/[-1]", f"{unsqueeze_name}/output_0"], + dtype=ir.DataType.INT64, + shape=[2], + axis=0, + ) gate_reshape_name = f"{gate_ops_base}/Reshape" - self.make_reshape(gate_reshape_name, [f"{gate_name}/output_0", f"{concat_name}/output_0"], dtype=self.io_dtype, shape=['num_rows', self.moe_attrs["num_experts"]]) + self.make_reshape( + gate_reshape_name, + [f"{gate_name}/output_0", f"{concat_name}/output_0"], + dtype=self.io_dtype, + shape=["num_rows", self.moe_attrs["num_experts"]], + ) w1_list = [] w2_list = [] @@ -2809,10 +3781,15 @@ def make_moe_initializer(w_list, moe_expert_name, dtype): make_moe_initializer(w3_scale_list, moe_expert_scales_3_name, self.io_dtype) self.make_moe_op( - moe_name, root_input=root_input, router_probs=f"{gate_reshape_name}/output_0", - weight1=moe_expert_weight_1_name, scales1=moe_expert_scales_1_name, - weight2=moe_expert_weight_2_name, scales2=moe_expert_scales_2_name, - weight3=moe_expert_weight_3_name, scales3=moe_expert_scales_3_name, + moe_name, + root_input=root_input, + router_probs=f"{gate_reshape_name}/output_0", + weight1=moe_expert_weight_1_name, + scales1=moe_expert_scales_1_name, + weight2=moe_expert_weight_2_name, + scales2=moe_expert_scales_2_name, + weight3=moe_expert_weight_3_name, + scales3=moe_expert_scales_3_name, ) # Assign output 0 of previous MoE as root input to next SkipLayerNorm @@ -2828,12 +3805,27 @@ def make_activation_with_mul(self, layer_id, root_input, activation, domain): # Mul act_name = f"/model/layers.{layer_id}/mlp/act_fn/{activation}" act_output = f"{act_name}/output_0" - self.make_node(activation, inputs=[root_input], outputs=[act_output], name=act_name, domain=domain) - self.make_value(act_output, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) + self.make_node( + activation, + inputs=[root_input], + outputs=[act_output], + name=act_name, + domain=domain, + ) + self.make_value( + act_output, + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.intermediate_size], + ) mul_act_name = f"/model/layers.{layer_id}/mlp/act_fn/Mul" mul_act_inputs = [root_input, act_output] - self.make_mul(mul_act_name, mul_act_inputs, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) + self.make_mul( + mul_act_name, + mul_act_inputs, + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.intermediate_size], + ) return mul_act_name @@ -2847,13 +3839,35 @@ def make_gelu(self, layer_id, root_input, activation): output = f"{gelu_name}/output_0" if activation == "Gelu": - self.make_node("Gelu", inputs=[root_input], outputs=[output], name=gelu_name, approximate="none") + self.make_node( + "Gelu", + inputs=[root_input], + outputs=[output], + name=gelu_name, + approximate="none", + ) elif activation == "FastGelu": - self.make_node("Gelu", inputs=[root_input], outputs=[output], name=gelu_name, approximate="tanh") + self.make_node( + "Gelu", + inputs=[root_input], + outputs=[output], + name=gelu_name, + approximate="tanh", + ) else: - self.make_node(activation, inputs=[root_input], outputs=[output], name=gelu_name, domain="com.microsoft") + self.make_node( + activation, + inputs=[root_input], + outputs=[output], + name=gelu_name, + domain="com.microsoft", + ) - self.make_value(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.intermediate_size]) + self.make_value( + output, + self.io_dtype, + shape=["batch_size", "sequence_length", self.intermediate_size], + ) return gelu_name @@ -2861,7 +3875,11 @@ def make_relu(self, layer_id, root_input, activation): relu_name = f"/model/layers.{layer_id}/mlp/act_fn/{activation}" output = f"{relu_name}/output_0" self.make_node(activation, inputs=[root_input], outputs=[output], name=relu_name, domain="") - self.make_value(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.intermediate_size]) + self.make_value( + output, + self.io_dtype, + shape=["batch_size", "sequence_length", self.intermediate_size], + ) return relu_name def make_relu_squared(self, layer_id, root_input, activation): @@ -2869,8 +3887,18 @@ def make_relu_squared(self, layer_id, root_input, activation): basename = f"/model/layers.{layer_id}/mlp/square/{activation}" pow_name = f"{basename}/pow" pow_inputs = [f"{relu_name}/output_0", "/model/constants/INT32/[2]"] - self.make_node("Pow", inputs=pow_inputs, outputs=[f"{pow_name}/output_0"], name=pow_name, domain="") - self.make_value(f"{pow_name}/output_0", self.io_dtype, shape=['batch_size', 'sequence_length', self.intermediate_size]) + self.make_node( + "Pow", + inputs=pow_inputs, + outputs=[f"{pow_name}/output_0"], + name=pow_name, + domain="", + ) + self.make_value( + f"{pow_name}/output_0", + self.io_dtype, + shape=["batch_size", "sequence_length", self.intermediate_size], + ) return pow_name def make_activation(self, layer_id, root_input): @@ -2900,7 +3928,13 @@ def make_lm_head(self, lm_head): # List order matters here. It should match the order of the below if condition checks. # Add new checks to the end of the list and after the below if condition checks. - exists_checks = [bias_exists, scale_exists, mask_exists, softcap_exists, cast_exists] + exists_checks = [ + bias_exists, + scale_exists, + mask_exists, + softcap_exists, + cast_exists, + ] matmul_basename = "/lm_head/MatMul" root_input = self.layernorm_attrs["output_0"] @@ -2909,15 +3943,27 @@ def make_lm_head(self, lm_head): if bias_exists: add_name = "/lm_head/Add" - self.make_add_bias(lm_head.bias, add_name, root_input=f"{lm_name}/output_0", logits=not any(exists_checks[1:])) + self.make_add_bias( + lm_head.bias, + add_name, + root_input=f"{lm_name}/output_0", + logits=not any(exists_checks[1:]), + ) lm_name = add_name if scale_exists: mul_name = "/lm_head/Mul" - mul_inputs = [f"{lm_name}/output_0", f"/model/constants/{self.to_str_dtype(self.io_dtype)}/{self.lm_head_attrs['scale']}"] + mul_inputs = [ + f"{lm_name}/output_0", + f"/model/constants/{self.to_str_dtype(self.io_dtype)}/{self.lm_head_attrs['scale']}", + ] mul_output = "logits" if not any(exists_checks[2:]) else f"{mul_name}/output_0" - self.make_node('Mul', inputs=mul_inputs, outputs=[mul_output], name=mul_name) - self.make_value(mul_output, self.io_dtype, shape=['batch_size', 'sequence_length', self.vocab_size]) + self.make_node("Mul", inputs=mul_inputs, outputs=[mul_output], name=mul_name) + self.make_value( + mul_output, + self.io_dtype, + shape=["batch_size", "sequence_length", self.vocab_size], + ) lm_name = mul_name if mask_exists: @@ -2926,41 +3972,91 @@ def make_lm_head(self, lm_head): self.make_initializer(self.lm_head_attrs["mask"], logits_mask_name) where_name = "/lm_head/Where" - where_inputs = [logits_mask_name, f"/model/constants/{self.to_str_dtype(self.io_dtype)}/{torch.finfo(to_torch_dtype(self.io_dtype)).min}", f"{lm_name}/output_0"] + where_inputs = [ + logits_mask_name, + f"/model/constants/{self.to_str_dtype(self.io_dtype)}/{torch.finfo(to_torch_dtype(self.io_dtype)).min}", + f"{lm_name}/output_0", + ] where_output = "logits" if not any(exists_checks[3:]) else f"{where_name}/output_0" - self.make_node('Where', inputs=where_inputs, outputs=[where_output], name=where_name) - self.make_value(where_output, self.io_dtype, shape=['batch_size', 'sequence_length', self.vocab_size]) + self.make_node("Where", inputs=where_inputs, outputs=[where_output], name=where_name) + self.make_value( + where_output, + self.io_dtype, + shape=["batch_size", "sequence_length", self.vocab_size], + ) lm_name = where_name if softcap_exists: # Add final logit softcapping (Div --> Tanh --> Mul) div_name = "/lm_head/softcap/Div" - div_inputs = [f"{lm_name}/output_0", f"/model/constants/{self.to_str_dtype(self.io_dtype)}/{self.lm_head_attrs['softcap']}"] - self.make_div(div_name, div_inputs, dtype=self.io_dtype, shape=['batch_size', 'sequence_length', self.vocab_size]) + div_inputs = [ + f"{lm_name}/output_0", + f"/model/constants/{self.to_str_dtype(self.io_dtype)}/{self.lm_head_attrs['softcap']}", + ] + self.make_div( + div_name, + div_inputs, + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.vocab_size], + ) tanh_name = "/lm_head/softcap/Tanh" - self.make_tanh(tanh_name, f"{div_name}/output_0", dtype=self.io_dtype, shape=['batch_size', 'sequence_length', self.vocab_size]) + self.make_tanh( + tanh_name, + f"{div_name}/output_0", + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.vocab_size], + ) mul_name = "/lm_head/softcap/Mul" - mul_inputs = [f"{tanh_name}/output_0", f"/model/constants/{self.to_str_dtype(self.io_dtype)}/{self.lm_head_attrs['softcap']}"] + mul_inputs = [ + f"{tanh_name}/output_0", + f"/model/constants/{self.to_str_dtype(self.io_dtype)}/{self.lm_head_attrs['softcap']}", + ] mul_output = "logits" if not any(exists_checks[4:]) else f"{mul_name}/output_0" - self.make_node('Mul', inputs=mul_inputs, outputs=[mul_output], name=mul_name) - self.make_value(mul_output, self.io_dtype, shape=['batch_size', 'sequence_length', self.vocab_size]) + self.make_node("Mul", inputs=mul_inputs, outputs=[mul_output], name=mul_name) + self.make_value( + mul_output, + self.io_dtype, + shape=["batch_size", "sequence_length", self.vocab_size], + ) lm_name = mul_name if cast_exists: # Add final cast from io_dtype to logits_dtype cast_name = "/lm_head/Cast" cast_output = "logits" - self.make_node('Cast', inputs=[f"{lm_name}/output_0"], outputs=[cast_output], name=cast_name, to=self.output_types['logits']) - self.make_value(cast_output, self.output_types['logits'], shape=['batch_size', 'sequence_length', self.vocab_size]) + self.make_node( + "Cast", + inputs=[f"{lm_name}/output_0"], + outputs=[cast_output], + name=cast_name, + to=self.output_types["logits"], + ) + self.make_value( + cast_output, + self.output_types["logits"], + shape=["batch_size", "sequence_length", self.vocab_size], + ) def make_layer(self, layer_id, layer): # Each LLM decoder layer is typically defined as: # input_layernorm --> attention --> output_layernorm --> MLP - self.make_layernorm(layer_id, layer.input_layernorm, skip=not self.layernorm_attrs["first_layernorm"], simple=self.layernorm_attrs["simple"], location="input") + self.make_layernorm( + layer_id, + layer.input_layernorm, + skip=not self.layernorm_attrs["first_layernorm"], + simple=self.layernorm_attrs["simple"], + location="input", + ) self.make_attention(layer_id, layer.self_attn, root_input=self.layernorm_attrs["output_0"]) - self.make_layernorm(layer_id, layer.post_attention_layernorm, skip=True, simple=self.layernorm_attrs["simple"], location="post_attention") + self.make_layernorm( + layer_id, + layer.post_attention_layernorm, + skip=True, + simple=self.layernorm_attrs["simple"], + location="post_attention", + ) self.make_mlp(layer_id, layer.mlp, root_input=self.layernorm_attrs["output_0"]) self.layernorm_attrs["first_layernorm"] = False @@ -2982,7 +4078,16 @@ def make_model(self, input_path): from gguf_model import GGUFModel except ImportError: from onnxruntime_genai.models.gguf_model import GGUFModel - model = GGUFModel.from_pretrained(self.model_type, input_path, self.head_size, self.hidden_size, self.intermediate_size, self.num_attn_heads, self.num_kv_heads, self.vocab_size) + model = GGUFModel.from_pretrained( + self.model_type, + input_path, + self.head_size, + self.hidden_size, + self.intermediate_size, + self.num_attn_heads, + self.num_kv_heads, + self.vocab_size, + ) self.layernorm_attrs["add_offset"] = 0 # add offset already done for GGUF models elif self.quant_type is not None: @@ -2993,20 +4098,42 @@ def make_model(self, input_path): from onnxruntime_genai.models.quantized_model import QuantModel q_size = self.num_attn_heads * self.head_size kv_size = self.num_kv_heads * self.head_size - model = QuantModel.from_pretrained(self.quant_type, input_path=input_path, quant_attrs=self.quant_attrs, q_size=q_size, kv_size=kv_size, intermediate_size=self.intermediate_size, num_layers=self.num_layers) + model = QuantModel.from_pretrained( + self.quant_type, + input_path=input_path, + quant_attrs=self.quant_attrs, + q_size=q_size, + kv_size=kv_size, + intermediate_size=self.intermediate_size, + num_layers=self.num_layers, + ) else: # Load PyTorch model extra_kwargs = {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {} - model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, cache_dir=self.cache_dir, token=self.hf_token, trust_remote_code=self.hf_remote, **extra_kwargs) + model = AutoModelForCausalLM.from_pretrained( + self.model_name_or_path, + cache_dir=self.cache_dir, + token=self.hf_token, + trust_remote_code=self.hf_remote, + **extra_kwargs, + ) if "adapter_path" in self.extra_options: from peft import PeftModel - model = PeftModel.from_pretrained(model, self.extra_options["adapter_path"], cache_dir=self.cache_dir, token=self.hf_token) + + model = PeftModel.from_pretrained( + model, + self.extra_options["adapter_path"], + cache_dir=self.cache_dir, + token=self.hf_token, + ) # Loop through model and map each module to ONNX/ORT ops self.layer_id = 0 for module in model.modules(): - if (isinstance(module, torch.nn.Embedding) and module.weight.shape[0] == self.vocab_size) or (hasattr(model, "embedding") and module == model.embedding): + if (isinstance(module, torch.nn.Embedding) and module.weight.shape[0] == self.vocab_size) or ( + hasattr(model, "embedding") and module == model.embedding + ): # Checks (Hugging Face logic) or (GGUF logic) if not self.exclude_embeds: # Embedding layer @@ -3017,7 +4144,9 @@ def make_model(self, input_path): self.layernorm_attrs["root_input"] = "inputs_embeds" self.layernorm_attrs["skip_input"] = "inputs_embeds" - elif (module.__class__.__name__.endswith("DecoderLayer") or module.__class__.__name__.endswith("GLMBlock")) and self.layer_id < self.num_layers: + elif ( + module.__class__.__name__.endswith("DecoderLayer") or module.__class__.__name__.endswith("GLMBlock") + ) and self.layer_id < self.num_layers: # Each decoder layer of model print(f"Reading decoder layer {self.layer_id}") self.make_layer(self.layer_id, module) @@ -3026,9 +4155,17 @@ def make_model(self, input_path): elif self.layer_id == self.num_layers and self.has_final_norm(module, model): # SkipLayerNorm after last decoder layer (MatMul --> SkipLayerNorm) print("Reading final norm") - self.make_layernorm(self.layer_id, module, skip=True, simple=self.layernorm_attrs["simple"], location="final_norm") - - elif (isinstance(module, torch.nn.Linear) and module.out_features == self.vocab_size) or (hasattr(model, "lm_head") and module == model.lm_head): + self.make_layernorm( + self.layer_id, + module, + skip=True, + simple=self.layernorm_attrs["simple"], + location="final_norm", + ) + + elif (isinstance(module, torch.nn.Linear) and module.out_features == self.vocab_size) or ( + hasattr(model, "lm_head") and module == model.lm_head + ): # Checks (Hugging Face logic) or (GGUF logic) if not self.exclude_lm_head: # Language modeling head (SkipLayerNorm --> logits) @@ -3053,14 +4190,33 @@ def has_final_norm(self, module, orig_model): # hf_transformer_final_layernorm: for ChatGLM-3 # hf_language_model_norm: for Gemma-3 multimodal (4B, 12B, 27B) hf_norm = hasattr(model, "model") and hasattr(model.model, "norm") and module == model.model.norm - hf_final_layernorm = hasattr(model, "model") and hasattr(model.model, "final_layernorm") and module == model.model.final_layernorm - hf_transformer_final_layernorm = hasattr(model, "transformer") and hasattr(model.transformer, "encoder") and hasattr(model.transformer.encoder, "final_layernorm") and module == model.transformer.encoder.final_layernorm - hf_language_model_norm = hasattr(model, "model") and hasattr(model.model, "language_model") and hasattr(model.model.language_model, "norm") and module == model.model.language_model.norm + hf_final_layernorm = ( + hasattr(model, "model") + and hasattr(model.model, "final_layernorm") + and module == model.model.final_layernorm + ) + hf_transformer_final_layernorm = ( + hasattr(model, "transformer") + and hasattr(model.transformer, "encoder") + and hasattr(model.transformer.encoder, "final_layernorm") + and module == model.transformer.encoder.final_layernorm + ) + hf_language_model_norm = ( + hasattr(model, "model") + and hasattr(model.model, "language_model") + and hasattr(model.model.language_model, "norm") + and module == model.model.language_model.norm + ) # GGUF names (all models loaded with GGUFModel.from_pretrained) gguf_final_norm = hasattr(model, "final_norm") and module == model.final_norm - hf_names = [hf_norm, hf_final_layernorm, hf_transformer_final_layernorm, hf_language_model_norm] + hf_names = [ + hf_norm, + hf_final_layernorm, + hf_transformer_final_layernorm, + hf_language_model_norm, + ] gguf_names = [gguf_final_norm] return any(hf_names + gguf_names) @@ -3068,7 +4224,11 @@ def make_preprocessing_nodes(self): self.make_attention_mask_reformatting() def make_attention_mask_reformatting(self): - if self.extra_options.get("enable_cuda_graph", False) or self.extra_options.get("enable_webgpu_graph", False) or self.ep == "dml": + if ( + self.extra_options.get("enable_cuda_graph", False) + or self.extra_options.get("enable_webgpu_graph", False) + or self.ep == "dml" + ): # ORT does not allow nodes to be placed on mulitple execution providers # with graph capture enabled. We've only verified it works with GQA and with # past_present_share_buffer enabled(so the total_seq_len in GQA is hardcoded @@ -3165,13 +4325,30 @@ def make_attention_mask_reformatting_for_mha(self): end_add_name = f"{basename}/Add" end_add_inputs = [f"{end_where_name}/output_0", f"{end_expand_name}/output_0"] - end_add_shape = ["batch_size", 1, "source_sequence_length", "target_sequence_length"] - self.make_add(end_add_name, end_add_inputs, dtype=self.io_dtype, shape=end_add_shape) # Shape of mask is now (B, 1, S, T) + end_add_shape = [ + "batch_size", + 1, + "source_sequence_length", + "target_sequence_length", + ] + self.make_add( + end_add_name, end_add_inputs, dtype=self.io_dtype, shape=end_add_shape + ) # Shape of mask is now (B, 1, S, T) tile_name = f"{basename}/Tile" - tile_inputs = [f"{end_add_name}/output_0", f"/model/constants/INT64/[1, {self.num_attn_heads}, 1, 1]"] - tile_shape = ["batch_size", self.num_attn_heads, "source_sequence_length", "target_sequence_length"] - self.make_tile(tile_name, tile_inputs, dtype=self.io_dtype, shape=tile_shape) # Shape of mask is now (B, N, S, T) + tile_inputs = [ + f"{end_add_name}/output_0", + f"/model/constants/INT64/[1, {self.num_attn_heads}, 1, 1]", + ] + tile_shape = [ + "batch_size", + self.num_attn_heads, + "source_sequence_length", + "target_sequence_length", + ] + self.make_tile( + tile_name, tile_inputs, dtype=self.io_dtype, shape=tile_shape + ) # Shape of mask is now (B, N, S, T) self.mask_attrs["mask_name"] = tile_name @@ -3195,10 +4372,16 @@ def make_input_ids_subgraph(self, basename, past_key_gather_name): # | # Unsqueeze shared_add_name = f"{basename}/Add_1" - shared_add_inputs = [f"{basename}/Gather_2/output_0", f"{past_key_gather_name}/output_0"] + shared_add_inputs = [ + f"{basename}/Gather_2/output_0", + f"{past_key_gather_name}/output_0", + ] self.make_add(shared_add_name, shared_add_inputs, dtype=ir.DataType.INT64, shape=[]) unsqueeze_3_name = f"{basename}/Unsqueeze_3" # shared unsqueeze for input_ids and past_key_values.0.key - unsqueeze_3_inputs = [f"{shared_add_name}/output_0", "/model/constants/INT64/[0]"] + unsqueeze_3_inputs = [ + f"{shared_add_name}/output_0", + "/model/constants/INT64/[0]", + ] self.make_unsqueeze(unsqueeze_3_name, unsqueeze_3_inputs, dtype=ir.DataType.INT64, shape=[1]) # Make the additional subgraph for input_ids @@ -3208,7 +4391,10 @@ def make_input_ids_subgraph(self, basename, past_key_gather_name): # Gather (idx=1) --> Concat --> ConstantOfShape Reshape --> Less --> Where --> Unsqueeze --> Unsqueeze --> Expand # \ / \ | # Unsqueeze (unsqueeze_5) Shape --> Slice --> Squeeze --> Range --> Add -------+ - unsqueeze_inputs = [f"{basename}/Gather_2/output_0", "/model/constants/INT64/[0]"] + unsqueeze_inputs = [ + f"{basename}/Gather_2/output_0", + "/model/constants/INT64/[0]", + ] unsqueeze_4_name = f"{basename}/Unsqueeze_4" self.make_unsqueeze(unsqueeze_4_name, unsqueeze_inputs, dtype=ir.DataType.INT64, shape=[1]) unsqueeze_5_name = f"{basename}/Unsqueeze_5" @@ -3220,20 +4406,40 @@ def make_input_ids_subgraph(self, basename, past_key_gather_name): self.make_concat(concat_2_name, concat_inputs, dtype=ir.DataType.INT64, shape=[2], axis=0) constant_shape_name = f"{basename}/ConstantOfShape_2" constant_shape_torch_dtype = to_torch_dtype(self.io_dtype) - constant_shape_value = ir.tensor(torch.tensor([torch.finfo(constant_shape_torch_dtype).min], dtype=constant_shape_torch_dtype), name="make_input_ids_subgraph_shape") - self.make_constant_of_shape(constant_shape_name, f"{concat_2_name}/output_0", value=constant_shape_value, dtype=self.io_dtype, shape=['unk', 'unk']) + constant_shape_value = ir.tensor( + torch.tensor( + [torch.finfo(constant_shape_torch_dtype).min], + dtype=constant_shape_torch_dtype, + ), + name="make_input_ids_subgraph_shape", + ) + self.make_constant_of_shape( + constant_shape_name, + f"{concat_2_name}/output_0", + value=constant_shape_value, + dtype=self.io_dtype, + shape=["unk", "unk"], + ) # Top path shape_4_name = f"{basename}/Shape_4" self.make_shape(shape_4_name, f"{constant_shape_name}/output_0", shape=[2]) slice_1_name = f"{basename}/Slice_1" - slice_1_inputs = [f"{shape_4_name}/output_0", "/model/constants/INT64/[-1]", f"/model/constants/INT64/[{torch.iinfo(torch.int64).max}]", "/model/constants/INT64/[0]"] + slice_1_inputs = [ + f"{shape_4_name}/output_0", + "/model/constants/INT64/[-1]", + f"/model/constants/INT64/[{torch.iinfo(torch.int64).max}]", + "/model/constants/INT64/[0]", + ] self.make_slice(slice_1_name, slice_1_inputs, dtype=ir.DataType.INT64, shape=[1]) squeeze_1_name = f"{basename}/Squeeze_1" squeeze_1_inputs = [f"{slice_1_name}/output_0", "/model/constants/INT64/[0]"] self.make_squeeze(squeeze_1_name, squeeze_1_inputs, dtype=ir.DataType.INT64, shape=[]) unsqueeze_7_name = f"{basename}/output_0" - unsqueeze_7_inputs = [f"{squeeze_1_name}/output_0", "/model/constants/INT64/[0]"] + unsqueeze_7_inputs = [ + f"{squeeze_1_name}/output_0", + "/model/constants/INT64/[0]", + ] self.make_unsqueeze(unsqueeze_7_name, unsqueeze_7_inputs, dtype=ir.DataType.INT64, shape=[1]) concat_3_name = f"{basename}/Concat_3" concat_3_inputs = [f"{unsqueeze_7_name}/output_0", "/model/constants/INT64/[1]"] @@ -3243,13 +4449,22 @@ def make_input_ids_subgraph(self, basename, past_key_gather_name): shape_5_name = f"{basename}/Shape_5" self.make_shape(shape_5_name, f"{constant_shape_name}/output_0", shape=[2]) slice_2_name = f"{basename}/Slice_2" - slice_2_inputs = [f"{shape_5_name}/output_0", "/model/constants/INT64/[-1]", f"/model/constants/INT64/[{torch.iinfo(torch.int64).max}]", "/model/constants/INT64/[0]"] + slice_2_inputs = [ + f"{shape_5_name}/output_0", + "/model/constants/INT64/[-1]", + f"/model/constants/INT64/[{torch.iinfo(torch.int64).max}]", + "/model/constants/INT64/[0]", + ] self.make_slice(slice_2_name, slice_2_inputs, dtype=ir.DataType.INT64, shape=[1]) squeeze_2_name = f"{basename}/Squeeze_2" squeeze_2_inputs = [f"{slice_2_name}/output_0", "/model/constants/INT64/[0]"] self.make_squeeze(squeeze_2_name, squeeze_2_inputs, dtype=ir.DataType.INT64, shape=[]) range_name = f"{basename}/Range" - range_inputs = ["/model/constants/INT64/0", f"{squeeze_2_name}/output_0", "/model/constants/INT64/1"] + range_inputs = [ + "/model/constants/INT64/0", + f"{squeeze_2_name}/output_0", + "/model/constants/INT64/1", + ] self.make_range(range_name, range_inputs) add_2_name = f"{basename}/Add_2" add_inputs = [f"{range_name}/output_0", "/model/constants/INT64/1"] @@ -3263,16 +4478,29 @@ def make_input_ids_subgraph(self, basename, past_key_gather_name): less_inputs = [f"{range_name}/output_0", f"{reshape_name}/output_0"] self.make_less(less_name, less_inputs) where_2_name = f"{basename}/Where_2" - where_2_inputs = [f"{less_name}/output_0", f"/model/constants/{self.to_str_dtype(self.io_dtype)}/0", f"{constant_shape_name}/output_0"] + where_2_inputs = [ + f"{less_name}/output_0", + f"/model/constants/{self.to_str_dtype(self.io_dtype)}/0", + f"{constant_shape_name}/output_0", + ] self.make_where(where_2_name, where_2_inputs, dtype=self.io_dtype, shape=None) unsqueeze_8_name = f"{basename}/Unsqueeze_8" unsqueeze_8_inputs = [f"{where_2_name}/output_0", "/model/constants/INT64/[0]"] self.make_unsqueeze(unsqueeze_8_name, unsqueeze_8_inputs, dtype=self.io_dtype, shape=None) unsqueeze_9_name = f"{basename}/Unsqueeze_9" - unsqueeze_9_inputs = [f"{unsqueeze_8_name}/output_0", "/model/constants/INT64/[1]"] + unsqueeze_9_inputs = [ + f"{unsqueeze_8_name}/output_0", + "/model/constants/INT64/[1]", + ] self.make_unsqueeze(unsqueeze_9_name, unsqueeze_9_inputs, dtype=self.io_dtype, shape=None) - expand_name = self.make_common_mask_reformat_subgraph(basename, root_input="input_ids" if not self.exclude_embeds else "inputs_embeds", unsqueeze_for_concat=unsqueeze_3_name, unsqueeze_for_expand=unsqueeze_9_name, input_ids_subgraph=True) + expand_name = self.make_common_mask_reformat_subgraph( + basename, + root_input="input_ids" if not self.exclude_embeds else "inputs_embeds", + unsqueeze_for_concat=unsqueeze_3_name, + unsqueeze_for_expand=unsqueeze_9_name, + input_ids_subgraph=True, + ) return unsqueeze_6_name, expand_name def make_attention_mask_subgraph(self, basename, unsqueeze_for_concat): @@ -3282,34 +4510,90 @@ def make_attention_mask_subgraph(self, basename, unsqueeze_for_concat): unsqueeze_3_name = f"{basename}/Unsqueeze_3" unsqueeze_3_inputs = ["attention_mask", "/model/constants/INT64/[1]"] - attention_mask_shape.insert(1, 1) # ['batch_size', 'total_sequence_length'] --> ['batch_size', 1, 'total_sequence_length'] - self.make_unsqueeze(unsqueeze_3_name, unsqueeze_3_inputs, dtype=ir.DataType.INT64, shape=attention_mask_shape) + attention_mask_shape.insert( + 1, 1 + ) # ['batch_size', 'total_sequence_length'] --> ['batch_size', 1, 'total_sequence_length'] + self.make_unsqueeze( + unsqueeze_3_name, + unsqueeze_3_inputs, + dtype=ir.DataType.INT64, + shape=attention_mask_shape, + ) unsqueeze_4_name = f"{basename}/Unsqueeze_4" - unsqueeze_4_inputs = [f"{unsqueeze_3_name}/output_0", "/model/constants/INT64/[2]"] - attention_mask_shape.insert(1, 1) # ['batch_size', 1, 'total_sequence_length'] --> ['batch_size', 1, 1, 'total_sequence_length'] - self.make_unsqueeze(unsqueeze_4_name, unsqueeze_4_inputs, dtype=ir.DataType.INT64, shape=attention_mask_shape) + unsqueeze_4_inputs = [ + f"{unsqueeze_3_name}/output_0", + "/model/constants/INT64/[2]", + ] + attention_mask_shape.insert( + 1, 1 + ) # ['batch_size', 1, 'total_sequence_length'] --> ['batch_size', 1, 1, 'total_sequence_length'] + self.make_unsqueeze( + unsqueeze_4_name, + unsqueeze_4_inputs, + dtype=ir.DataType.INT64, + shape=attention_mask_shape, + ) # Make the main subgraph - expand_name = self.make_common_mask_reformat_subgraph(basename, root_input="attention_mask", unsqueeze_for_concat=unsqueeze_for_concat, unsqueeze_for_expand=unsqueeze_4_name) + expand_name = self.make_common_mask_reformat_subgraph( + basename, + root_input="attention_mask", + unsqueeze_for_concat=unsqueeze_for_concat, + unsqueeze_for_expand=unsqueeze_4_name, + ) # Make the additional subgraph after Expand: # +-----------------+ # | | # Expand --> Cast --> Sub --> Cast --> Where cast_1_name = f"{basename}/Cast_1" - self.make_cast(cast_1_name, f"{expand_name}/output_0", dtype=self.io_dtype, shape=["unk", "unk", "unk", "unk"]) + self.make_cast( + cast_1_name, + f"{expand_name}/output_0", + dtype=self.io_dtype, + shape=["unk", "unk", "unk", "unk"], + ) sub_name = f"{basename}/Sub" - sub_inputs = [f"/model/constants/{self.to_str_dtype(self.io_dtype)}/1", f"{cast_1_name}/output_0"] - self.make_sub(sub_name, sub_inputs, dtype=self.io_dtype, shape=["unk", "unk", "unk", "unk"]) + sub_inputs = [ + f"/model/constants/{self.to_str_dtype(self.io_dtype)}/1", + f"{cast_1_name}/output_0", + ] + self.make_sub( + sub_name, + sub_inputs, + dtype=self.io_dtype, + shape=["unk", "unk", "unk", "unk"], + ) cast_2_name = f"{basename}/Cast_2" - self.make_cast(cast_2_name, f"{sub_name}/output_0", dtype=ir.DataType.BOOL, shape=["unk", "unk", "unk", "unk"]) + self.make_cast( + cast_2_name, + f"{sub_name}/output_0", + dtype=ir.DataType.BOOL, + shape=["unk", "unk", "unk", "unk"], + ) where_2_name = f"{basename}/Where_2" - where_2_inputs = [f"{cast_2_name}/output_0", f"/model/constants/{self.to_str_dtype(self.io_dtype)}/{torch.finfo(to_torch_dtype(self.io_dtype)).min}", f"{sub_name}/output_0"] - self.make_where(where_2_name, where_2_inputs, dtype=self.io_dtype, shape=["unk", "unk", "unk", "unk"]) + where_2_inputs = [ + f"{cast_2_name}/output_0", + f"/model/constants/{self.to_str_dtype(self.io_dtype)}/{torch.finfo(to_torch_dtype(self.io_dtype)).min}", + f"{sub_name}/output_0", + ] + self.make_where( + where_2_name, + where_2_inputs, + dtype=self.io_dtype, + shape=["unk", "unk", "unk", "unk"], + ) return where_2_name - def make_common_mask_reformat_subgraph(self, basename, root_input, unsqueeze_for_concat, unsqueeze_for_expand, input_ids_subgraph=False): + def make_common_mask_reformat_subgraph( + self, + basename, + root_input, + unsqueeze_for_concat, + unsqueeze_for_expand, + input_ids_subgraph=False, + ): # root_input # / \ # Shape Shape @@ -3347,9 +4631,17 @@ def make_common_mask_reformat_subgraph(self, basename, root_input, unsqueeze_for # Expand shape_1_name = f"{basename}/Shape_1" - self.make_shape(shape_1_name, root_input, shape=[3] if self.exclude_embeds and input_ids_subgraph else [2]) + self.make_shape( + shape_1_name, + root_input, + shape=[3] if self.exclude_embeds and input_ids_subgraph else [2], + ) shape_2_name = f"{basename}/Shape_2" - self.make_shape(shape_2_name, root_input, shape=[3] if self.exclude_embeds and input_ids_subgraph else [2]) + self.make_shape( + shape_2_name, + root_input, + shape=[3] if self.exclude_embeds and input_ids_subgraph else [2], + ) gather_1_name = f"{basename}/Gather_1" gather_1_inputs = [f"{shape_1_name}/output_0", "/model/constants/INT64/0"] self.make_gather(gather_1_name, gather_1_inputs, dtype=ir.DataType.INT64, shape=[], axis=0) @@ -3364,15 +4656,30 @@ def make_common_mask_reformat_subgraph(self, basename, root_input, unsqueeze_for self.make_unsqueeze(unsqueeze_2_name, unsqueeze_2_inputs, dtype=ir.DataType.INT64, shape=[1]) concat_name = f"{basename}/Concat" if not input_ids_subgraph else f"{basename}/Concat_1" - concat_first_two_inputs = [f"{unsqueeze_1_name}/output_0", "/model/constants/INT64/[1]"] - concat_last_two_inputs = [f"{unsqueeze_for_concat}/output_0", f"{unsqueeze_2_name}/output_0"] if not input_ids_subgraph else [f"{unsqueeze_2_name}/output_0", f"{unsqueeze_for_concat}/output_0"] + concat_first_two_inputs = [ + f"{unsqueeze_1_name}/output_0", + "/model/constants/INT64/[1]", + ] + concat_last_two_inputs = ( + [f"{unsqueeze_for_concat}/output_0", f"{unsqueeze_2_name}/output_0"] + if not input_ids_subgraph + else [f"{unsqueeze_2_name}/output_0", f"{unsqueeze_for_concat}/output_0"] + ) concat_inputs = concat_first_two_inputs + concat_last_two_inputs self.make_concat(concat_name, concat_inputs, dtype=ir.DataType.INT64, shape=[4], axis=0) shape_3_name = f"{basename}/Shape_3" self.make_shape(shape_3_name, f"{concat_name}/output_0", shape=[1]) - constant_shape_name = f"{basename}/ConstantOfShape" if not input_ids_subgraph else f"{basename}/ConstantOfShape_1" + constant_shape_name = ( + f"{basename}/ConstantOfShape" if not input_ids_subgraph else f"{basename}/ConstantOfShape_1" + ) constant_shape_value = ir.tensor([1], dtype=ir.DataType.INT64) - self.make_constant_of_shape(constant_shape_name, f"{shape_3_name}/output_0", value=constant_shape_value, dtype=ir.DataType.INT64, shape=["unk"]) + self.make_constant_of_shape( + constant_shape_name, + f"{shape_3_name}/output_0", + value=constant_shape_value, + dtype=ir.DataType.INT64, + shape=["unk"], + ) mul_name = f"{basename}/Mul" mul_inputs = [f"{constant_shape_name}/output_0", "/model/constants/INT64/-1"] self.make_mul(mul_name, mul_inputs, dtype=ir.DataType.INT64, shape=["unk"]) @@ -3381,7 +4688,11 @@ def make_common_mask_reformat_subgraph(self, basename, root_input, unsqueeze_for self.make_equal(equal_name, equal_inputs, shape=[4]) where_name = f"{basename}/Where_1" - where_inputs = [f"{equal_name}/output_0", f"{constant_shape_name}/output_0", f"{concat_name}/output_0"] + where_inputs = [ + f"{equal_name}/output_0", + f"{constant_shape_name}/output_0", + f"{concat_name}/output_0", + ] self.make_where(where_name, where_inputs, dtype=ir.DataType.INT64, shape=[4]) expand_name = f"{basename}/Expand" expand_inputs = [f"{unsqueeze_for_expand}/output_0", f"{where_name}/output_0"] @@ -3413,10 +4724,20 @@ def make_attention_mask_graph_capture_reformatting_for_gqa(self, attn_mask_basen # Calculate ReduceSum from attention_mask cast_1_name = f"{attn_mask_basename}/Cast" - self.make_cast(cast_1_name, "attention_mask", dtype=ir.DataType.INT32, shape=["batch_size", "total_sequence_length"]) + self.make_cast( + cast_1_name, + "attention_mask", + dtype=ir.DataType.INT32, + shape=["batch_size", "total_sequence_length"], + ) reduce_sum_name = f"{attn_mask_basename}/ReduceSum" reduce_sum_inputs = [f"{cast_1_name}/output_0", "/model/constants/INT64/[1]"] - self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT32, shape=["batch_size", 1]) + self.make_reduce_sum( + reduce_sum_name, + reduce_sum_inputs, + dtype=ir.DataType.INT32, + shape=["batch_size", 1], + ) # Left branch: Calculate seqlens_k = ReduceSum - 1 sub_name = f"{attn_mask_basename}/Sub" @@ -3449,12 +4770,22 @@ def make_attention_mask_standard_reformatting_for_gqa(self, attn_mask_basename): # Left path reduce_sum_name = f"{attn_mask_basename}/ReduceSum" reduce_sum_inputs = ["attention_mask", "/model/constants/INT64/[1]"] - self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size", 1]) + self.make_reduce_sum( + reduce_sum_name, + reduce_sum_inputs, + dtype=ir.DataType.INT64, + shape=["batch_size", 1], + ) sub_name = f"{attn_mask_basename}/Sub" sub_inputs = [f"{reduce_sum_name}/output_0", "/model/constants/INT64/[1]"] self.make_sub(sub_name, sub_inputs, dtype=ir.DataType.INT64, shape=["batch_size", 1]) cast_1_name = f"{attn_mask_basename}/Sub/Cast" - self.make_cast(cast_1_name, f"{sub_name}/output_0", dtype=ir.DataType.INT32, shape=["batch_size", 1]) + self.make_cast( + cast_1_name, + f"{sub_name}/output_0", + dtype=ir.DataType.INT32, + shape=["batch_size", 1], + ) # Right path shape_name = f"{attn_mask_basename}/Shape" @@ -3500,9 +4831,19 @@ def make_attention_mask_reformatting_for_sparse_attn(self): # Left path reduce_sum_name = f"{attn_mask_basename}/ReduceSum" reduce_sum_inputs = ["attention_mask", "/model/constants/INT64/[1]"] - self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size", 1]) + self.make_reduce_sum( + reduce_sum_name, + reduce_sum_inputs, + dtype=ir.DataType.INT64, + shape=["batch_size", 1], + ) cast_1_name = f"{attn_mask_basename}/ReduceSum/Cast" - self.make_cast(cast_1_name, f"{reduce_sum_name}/output_0", dtype=ir.DataType.INT32, shape=["batch_size", 1]) + self.make_cast( + cast_1_name, + f"{reduce_sum_name}/output_0", + dtype=ir.DataType.INT32, + shape=["batch_size", 1], + ) # Right path shape_name = f"{attn_mask_basename}/Shape" diff --git a/src/python/py/models/builders/phi.py b/src/python/py/models/builders/phi.py index 8aa8eae8ab..c2a77dc35a 100644 --- a/src/python/py/models/builders/phi.py +++ b/src/python/py/models/builders/phi.py @@ -3,11 +3,12 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +import onnx_ir as ir +import torch + from .base import Model from .mistral import MistralModel -import onnx_ir as ir -import torch class PhiModel(Model): def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): @@ -18,13 +19,27 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): def make_layer(self, layer_id, layer): # Each Phi decoder layer is defined as: # input_layernorm --> attention --> MLP --> residual_add - self.make_layernorm(layer_id, layer.input_layernorm, skip=not self.layernorm_attrs["first_layernorm"], simple=self.layernorm_attrs["simple"], location="input") + self.make_layernorm( + layer_id, + layer.input_layernorm, + skip=not self.layernorm_attrs["first_layernorm"], + simple=self.layernorm_attrs["simple"], + location="input", + ) self.make_attention(layer_id, layer.self_attn, root_input=self.layernorm_attrs["output_0"]) self.make_mlp(layer_id, layer.mlp, root_input=self.layernorm_attrs["output_0"]) residual_add_name = f"/model/layers.{layer_id}/residual_add/Add" - residual_add_inputs = [self.layernorm_attrs['skip_input'], self.mlp_attrs["output_0"]] - self.make_add(residual_add_name, residual_add_inputs, dtype=self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size]) + residual_add_inputs = [ + self.layernorm_attrs["skip_input"], + self.mlp_attrs["output_0"], + ] + self.make_add( + residual_add_name, + residual_add_inputs, + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.hidden_size], + ) self.layernorm_attrs["first_layernorm"] = False if layer_id == self.num_layers - 1: @@ -34,6 +49,7 @@ def make_layer(self, layer_id, layer): # Assign output 0 of residual Add as skip input to next SkipLayerNorm self.layernorm_attrs["skip_input"] = f"{residual_add_name}/output_0" + class Phi3MiniModel(MistralModel): def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) @@ -47,7 +63,9 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): # Set position_ids_name based on whether position_ids is available as an input if "position_ids" in self.input_names: position_ids_result = self.make_position_ids_reformatting() - self.position_ids_name = f"{position_ids_result}/output_0" if position_ids_result != "position_ids" else "position_ids" + self.position_ids_name = ( + f"{position_ids_result}/output_0" if position_ids_result != "position_ids" else "position_ids" + ) else: # When position_ids is not an input (use_rope_in_attn is True), # position_ids won't be used since rotary embeddings are handled in GQA @@ -91,39 +109,72 @@ def make_position_ids_reformatting(self): input_tensor = "position_ids" if is_webgpu: cast_input_name = f"{basename}/Cast_input" - self.make_cast(cast_input_name, input_tensor, dtype=ir.DataType.INT32, shape=["batch_size", "sequence_length"]) + self.make_cast( + cast_input_name, + input_tensor, + dtype=ir.DataType.INT32, + shape=["batch_size", "sequence_length"], + ) input_tensor = f"{cast_input_name}/output_0" reduce_max_name = f"{basename}/ReduceMax" reduce_max_inputs = [input_tensor] self.make_reduce_max(reduce_max_name, reduce_max_inputs, dtype=compute_dtype, shape=[1]) greater_or_equal_name = f"{basename}/GreaterOrEqual" - greater_or_equal_inputs = [f"{reduce_max_name}/output_0", f"/model/constants/{compute_str_dtype}/{self.original_context_length}"] + greater_or_equal_inputs = [ + f"{reduce_max_name}/output_0", + f"/model/constants/{compute_str_dtype}/{self.original_context_length}", + ] self.make_greater_or_equal(greater_or_equal_name, greater_or_equal_inputs, shape=[]) cast_name = f"{basename}/Cast" - self.make_cast(cast_name, f"{greater_or_equal_name}/output_0", dtype=compute_dtype, shape=None) + self.make_cast( + cast_name, + f"{greater_or_equal_name}/output_0", + dtype=compute_dtype, + shape=None, + ) mul_name = f"{basename}/Mul" - mul_inputs = [f"{cast_name}/output_0", f"/model/constants/{compute_str_dtype}/{self.original_context_length}"] + mul_inputs = [ + f"{cast_name}/output_0", + f"/model/constants/{compute_str_dtype}/{self.original_context_length}", + ] self.make_mul(mul_name, mul_inputs, dtype=compute_dtype, shape=None) add_1_name = f"{basename}/Add_1" add_1_inputs = [f"{mul_name}/output_0", input_tensor] - self.make_add(add_1_name, add_1_inputs, dtype=compute_dtype, shape=["batch_size", "sequence_length"]) + self.make_add( + add_1_name, + add_1_inputs, + dtype=compute_dtype, + shape=["batch_size", "sequence_length"], + ) # Cast back to int64 for WebGPU to maintain compatibility result_name = add_1_name if is_webgpu: cast_output_name = f"{basename}/Cast_output" - self.make_cast(cast_output_name, f"{add_1_name}/output_0", dtype=ir.DataType.INT64, shape=["batch_size", "sequence_length"]) + self.make_cast( + cast_output_name, + f"{add_1_name}/output_0", + dtype=ir.DataType.INT64, + shape=["batch_size", "sequence_length"], + ) result_name = cast_output_name return result_name def make_attention(self, layer_id, attention, root_input, **kwargs): if self.position_ids_name is not None: - super().make_attention(layer_id, attention, root_input, position_ids=self.position_ids_name, **kwargs) + super().make_attention( + layer_id, + attention, + root_input, + position_ids=self.position_ids_name, + **kwargs, + ) else: super().make_attention(layer_id, attention, root_input, **kwargs) + class Phi3SmallModel(Model): def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) @@ -157,7 +208,7 @@ def calculate_block_mask(self): q_pos = torch.arange(N_BLOCK)[:, None] k_pos = torch.arange(N_BLOCK)[None] mask_vert_strided = (torch.arange(N_BLOCK) + 1) % vert_stride == 0 - block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)) + block_mask_dense = (q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided) N_BLOCK_Q = self.calculate_cdiv(q_len, BLOCK) block_mask_dense_output = block_mask_dense[-N_BLOCK_Q:].to_sparse_csr() @@ -170,9 +221,11 @@ def calculate_block_mask(self): q_pos = torch.arange(N_BLOCK)[None, :, None] k_pos = torch.arange(N_BLOCK)[None, None] head_sliding_step = max(1, int(vert_stride / n_heads)) # if vert_stride <= n_heads, rotating the heads - mask_vert_strided = [(torch.arange(N_BLOCK) + h * head_sliding_step + 1) % vert_stride == 0 for h in range(n_heads)] + mask_vert_strided = [ + (torch.arange(N_BLOCK) + h * head_sliding_step + 1) % vert_stride == 0 for h in range(n_heads) + ] mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1) - block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)) + block_mask_dense = (q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided) N_BLOCK_Q = self.calculate_cdiv(q_len, BLOCK) block_mask_dense_output = block_mask_dense[:, -N_BLOCK_Q:] @@ -213,20 +266,47 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): q_size = self.num_attn_heads * self.head_size kv_size = self.num_kv_heads * self.head_size - qkv_weight = attention.query_key_value.weight.T.view(self.hidden_size, self.num_kv_heads, (self.num_attn_heads // self.num_kv_heads) + 2, self.head_size) - qkv_bias = attention.query_key_value.bias.view(self.num_kv_heads, (self.num_attn_heads // self.num_kv_heads) + 2, self.head_size) + qkv_weight = attention.query_key_value.weight.T.view( + self.hidden_size, + self.num_kv_heads, + (self.num_attn_heads // self.num_kv_heads) + 2, + self.head_size, + ) + qkv_bias = attention.query_key_value.bias.view( + self.num_kv_heads, + (self.num_attn_heads // self.num_kv_heads) + 2, + self.head_size, + ) attention.q_proj = torch.nn.Linear(in_features=q_size, out_features=q_size) - attention.q_proj.weight = torch.nn.Parameter(qkv_weight[:, :, :-2].reshape(q_size, q_size).T, requires_grad=False) - attention.q_proj.bias = None if attention.query_key_value.bias is None else torch.nn.Parameter(qkv_bias[:, :-2].flatten(), requires_grad=False) + attention.q_proj.weight = torch.nn.Parameter( + qkv_weight[:, :, :-2].reshape(q_size, q_size).T, requires_grad=False + ) + attention.q_proj.bias = ( + None + if attention.query_key_value.bias is None + else torch.nn.Parameter(qkv_bias[:, :-2].flatten(), requires_grad=False) + ) attention.k_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) - attention.k_proj.weight = torch.nn.Parameter(qkv_weight[:, :, [-2]].reshape(q_size, kv_size).T, requires_grad=False) - attention.k_proj.bias = None if attention.query_key_value.bias is None else torch.nn.Parameter(qkv_bias[:, [-2]].flatten(), requires_grad=False) + attention.k_proj.weight = torch.nn.Parameter( + qkv_weight[:, :, [-2]].reshape(q_size, kv_size).T, requires_grad=False + ) + attention.k_proj.bias = ( + None + if attention.query_key_value.bias is None + else torch.nn.Parameter(qkv_bias[:, [-2]].flatten(), requires_grad=False) + ) attention.v_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) - attention.v_proj.weight = torch.nn.Parameter(qkv_weight[:, :, [-1]].reshape(q_size, kv_size).T, requires_grad=False) - attention.v_proj.bias = None if attention.query_key_value.bias is None else torch.nn.Parameter(qkv_bias[:, [-1]].flatten(), requires_grad=False) + attention.v_proj.weight = torch.nn.Parameter( + qkv_weight[:, :, [-1]].reshape(q_size, kv_size).T, requires_grad=False + ) + attention.v_proj.bias = ( + None + if attention.query_key_value.bias is None + else torch.nn.Parameter(qkv_bias[:, [-1]].flatten(), requires_grad=False) + ) del qkv_weight del qkv_bias @@ -275,43 +355,132 @@ def make_mlp_proj(self, layer_id, mlp, root_input): # Left path slice_1_name = f"/model/layers.{layer_id}/mlp/gelu/Slice" - slice_1_inputs = [f"{up_add_name}/output_0", "/model/constants/INT64/[0]", f"/model/constants/INT64/[{torch.iinfo(torch.int64).max}]", "/model/constants/INT64/[-1]", "/model/constants/INT64/[2]"] - self.make_slice(slice_1_name, slice_1_inputs, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) + slice_1_inputs = [ + f"{up_add_name}/output_0", + "/model/constants/INT64/[0]", + f"/model/constants/INT64/[{torch.iinfo(torch.int64).max}]", + "/model/constants/INT64/[-1]", + "/model/constants/INT64/[2]", + ] + self.make_slice( + slice_1_name, + slice_1_inputs, + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.intermediate_size], + ) cast_1_name = f"/model/layers.{layer_id}/mlp/gelu/Cast" - self.make_cast(cast_1_name, f"{slice_1_name}/output_0", dtype=ir.DataType.FLOAT, shape=["batch_size", "sequence_length", self.intermediate_size]) + self.make_cast( + cast_1_name, + f"{slice_1_name}/output_0", + dtype=ir.DataType.FLOAT, + shape=["batch_size", "sequence_length", self.intermediate_size], + ) isinf_1_name = f"/model/layers.{layer_id}/mlp/gelu/IsInf" - self.make_isinf(isinf_1_name, f"{cast_1_name}/output_0", shape=["batch_size", "sequence_length", self.intermediate_size]) + self.make_isinf( + isinf_1_name, + f"{cast_1_name}/output_0", + shape=["batch_size", "sequence_length", self.intermediate_size], + ) clip_1_name = f"/model/layers.{layer_id}/mlp/gelu/Clip" - clip_1_inputs = [f"{slice_1_name}/output_0", "", f"/model/constants/{self.to_str_dtype(self.io_dtype)}/{self.clamp_limit}"] - self.make_clip(clip_1_name, clip_1_inputs, self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) + clip_1_inputs = [ + f"{slice_1_name}/output_0", + "", + f"/model/constants/{self.to_str_dtype(self.io_dtype)}/{self.clamp_limit}", + ] + self.make_clip( + clip_1_name, + clip_1_inputs, + self.io_dtype, + shape=["batch_size", "sequence_length", self.intermediate_size], + ) where_1_name = f"/model/layers.{layer_id}/mlp/gelu/Where" - where_1_inputs = [f"{isinf_1_name}/output_0", f"{slice_1_name}/output_0", f"{clip_1_name}/output_0"] - self.make_where(where_1_name, where_1_inputs, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) + where_1_inputs = [ + f"{isinf_1_name}/output_0", + f"{slice_1_name}/output_0", + f"{clip_1_name}/output_0", + ] + self.make_where( + where_1_name, + where_1_inputs, + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.intermediate_size], + ) # Make activation act_fn_name = self.make_activation(layer_id, root_input=f"{where_1_name}/output_0") # Right path slice_2_name = f"/model/layers.{layer_id}/mlp/linear/Slice" - slice_2_inputs = [f"{up_add_name}/output_0", "/model/constants/INT64/[1]", f"/model/constants/INT64/[{torch.iinfo(torch.int64).max}]", "/model/constants/INT64/[-1]", "/model/constants/INT64/[2]"] - self.make_slice(slice_2_name, slice_2_inputs, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) + slice_2_inputs = [ + f"{up_add_name}/output_0", + "/model/constants/INT64/[1]", + f"/model/constants/INT64/[{torch.iinfo(torch.int64).max}]", + "/model/constants/INT64/[-1]", + "/model/constants/INT64/[2]", + ] + self.make_slice( + slice_2_name, + slice_2_inputs, + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.intermediate_size], + ) cast_2_name = f"/model/layers.{layer_id}/mlp/linear/Cast" - self.make_cast(cast_2_name, f"{slice_2_name}/output_0", dtype=ir.DataType.FLOAT, shape=["batch_size", "sequence_length", self.intermediate_size]) + self.make_cast( + cast_2_name, + f"{slice_2_name}/output_0", + dtype=ir.DataType.FLOAT, + shape=["batch_size", "sequence_length", self.intermediate_size], + ) isinf_2_name = f"/model/layers.{layer_id}/mlp/linear/IsInf" - self.make_isinf(isinf_2_name, f"{cast_2_name}/output_0", shape=["batch_size", "sequence_length", self.intermediate_size]) + self.make_isinf( + isinf_2_name, + f"{cast_2_name}/output_0", + shape=["batch_size", "sequence_length", self.intermediate_size], + ) clip_2_name = f"/model/layers.{layer_id}/mlp/linear/Clip" - clip_2_inputs = [f"{slice_2_name}/output_0", f"/model/constants/{self.to_str_dtype(self.io_dtype)}/-{self.clamp_limit}", f"/model/constants/{self.to_str_dtype(self.io_dtype)}/{self.clamp_limit}"] - self.make_clip(clip_2_name, clip_2_inputs, self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) + clip_2_inputs = [ + f"{slice_2_name}/output_0", + f"/model/constants/{self.to_str_dtype(self.io_dtype)}/-{self.clamp_limit}", + f"/model/constants/{self.to_str_dtype(self.io_dtype)}/{self.clamp_limit}", + ] + self.make_clip( + clip_2_name, + clip_2_inputs, + self.io_dtype, + shape=["batch_size", "sequence_length", self.intermediate_size], + ) where_2_name = f"/model/layers.{layer_id}/mlp/linear/Where" - where_2_inputs = [f"{isinf_2_name}/output_0", f"{slice_2_name}/output_0", f"{clip_2_name}/output_0"] - self.make_where(where_2_name, where_2_inputs, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) + where_2_inputs = [ + f"{isinf_2_name}/output_0", + f"{slice_2_name}/output_0", + f"{clip_2_name}/output_0", + ] + self.make_where( + where_2_name, + where_2_inputs, + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.intermediate_size], + ) add_name = f"/model/layers.{layer_id}/mlp/linear/Add" - add_inputs = [f"{where_2_name}/output_0", f"/model/constants/{self.to_str_dtype(self.io_dtype)}/1"] - self.make_add(add_name, add_inputs, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) + add_inputs = [ + f"{where_2_name}/output_0", + f"/model/constants/{self.to_str_dtype(self.io_dtype)}/1", + ] + self.make_add( + add_name, + add_inputs, + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.intermediate_size], + ) # Make Mul node after activation mul_name = f"/model/layers.{layer_id}/mlp/Mul" mul_inputs = [f"{act_fn_name}/output_0", f"{add_name}/output_0"] - self.make_mul(mul_name, mul_inputs, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) + self.make_mul( + mul_name, + mul_inputs, + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.intermediate_size], + ) # Make output MatMul and Add nodes down_matmul_name = f"/model/layers.{layer_id}/mlp/down_proj/MatMul" @@ -345,16 +514,33 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): def make_layer(self, layer_id, layer): # Each LLM decoder layer is typically defined as: # input_layernorm --> attention --> output_layernorm --> MoE - self.make_layernorm(layer_id, layer.input_layernorm, skip=not self.layernorm_attrs["first_layernorm"], simple=self.layernorm_attrs["simple"], location="input") + self.make_layernorm( + layer_id, + layer.input_layernorm, + skip=not self.layernorm_attrs["first_layernorm"], + simple=self.layernorm_attrs["simple"], + location="input", + ) self.make_attention(layer_id, layer.self_attn, root_input=self.layernorm_attrs["output_0"]) - self.make_layernorm(layer_id, layer.post_attention_layernorm, skip=True, simple=self.layernorm_attrs["simple"], location="post_attention") - self.make_block_sparse_moe(layer_id, layer.block_sparse_moe, root_input=self.layernorm_attrs["output_0"]) + self.make_layernorm( + layer_id, + layer.post_attention_layernorm, + skip=True, + simple=self.layernorm_attrs["simple"], + location="post_attention", + ) + self.make_block_sparse_moe( + layer_id, + layer.block_sparse_moe, + root_input=self.layernorm_attrs["output_0"], + ) self.layernorm_attrs["first_layernorm"] = False if layer_id == self.num_layers - 1: # Norm after last decoder layer of model (last layer --> norm) self.layernorm_attrs["last_layernorm"] = True + class Phi4MMModel(Phi3VModel): def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) @@ -376,4 +562,4 @@ def make_layer(self, layer_id, layer): layer.mlp.down_proj.lora_B.default = layer.mlp.down_proj.lora_B.vision layer.mlp.down_proj.scaling["default"] = layer.mlp.down_proj.scaling["vision"] - super().make_layer(layer_id, layer) \ No newline at end of file + super().make_layer(layer_id, layer) diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index 9219f2e34a..12eff67679 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -5,14 +5,18 @@ # -------------------------------------------------------------------------- import os -from .base import Model + import onnx_ir as ir import torch +from .base import Model + + class QwenModel(Model): def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) + class Qwen3Model(QwenModel): def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) @@ -22,11 +26,12 @@ def make_attention_init(self): self.attention_attrs["k_norm"] = True super().make_attention_init() -class Qwen25VLTextModel(QwenModel): + +class Qwen25VLTextModel(Model): def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): # We must extract the text_config for the text model's parameters text_config_dict = config.text_config.to_dict() - + # Update the main config with text-specific parameters # The base.Model class reads from the top-level config object config.hidden_size = text_config_dict["hidden_size"] @@ -66,7 +71,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): if "rope_cast" not in self.attention_attrs: self.attention_attrs["rope_cast"] = {} self.attention_attrs["rope_cast"]["use_fp32"] = True - + # The base.Model.make_outputs_init() *always* casts logits to float32 # if the io_dtype is bfloat16. This is to improve accuracy in general. # @@ -81,7 +86,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): if self.allow_bf16_logits and self.io_dtype == ir.DataType.BFLOAT16: print("Fixing output logits precision. Setting output_types['logits'] to BFLOAT16 to match HF model.") self.output_types["logits"] = ir.DataType.BFLOAT16 - + # Manually get the attention_scaling from the rope_config # This replicates the logic from transformers.models.rope_utils._config_to_init_values rope_type = "default" @@ -89,20 +94,25 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): # The config re-maps 'mrope' to 'default' if config.rope_scaling["type"] != "mrope": rope_type = config.rope_scaling["type"] - + if rope_type == "yarn": factor = config.rope_scaling.get("factor", 1.0) - self.rope_attrs["attention_scaling"] = config.rope_scaling.get("attention_factor", (0.1 * torch.log(torch.tensor(factor)) + 1.0).item()) + self.rope_attrs["attention_scaling"] = config.rope_scaling.get( + "attention_factor", (0.1 * torch.log(torch.tensor(factor)) + 1.0).item() + ) elif rope_type == "longrope": factor = config.rope_scaling.get("factor", 1.0) orig_max_pos = config.original_max_position_embeddings - self.rope_attrs["attention_scaling"] = config.rope_scaling.get("attention_factor", torch.sqrt(1 + torch.log(torch.tensor(factor)) / torch.log(torch.tensor(orig_max_pos))).item()) + self.rope_attrs["attention_scaling"] = config.rope_scaling.get( + "attention_factor", + torch.sqrt(1 + torch.log(torch.tensor(factor)) / torch.log(torch.tensor(orig_max_pos))).item(), + ) else: self.rope_attrs["attention_scaling"] = 1.0 - + # Qwen 2.5 VL applies RoPE manually before attention, not fused in the op self.attention_attrs["use_rope_in_attn"] = False - + # Your inheritance change fixed this, but this check is harmless and safe. if "position_ids" not in self.input_names: print("Re-adding 'position_ids' to self.input_names.") @@ -111,27 +121,31 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): self.input_names.insert(idx + 1, "position_ids") else: self.input_names.append("position_ids") - + self.mrope_sections = self.rope_attrs.get("mrope", {}).get("sections", []) if not self.mrope_sections: raise ValueError("MRoPE sections not found in config.text_config.rope_scaling.mrope_section") - + # The HF logic is `mrope_section * 2`, not `[s * 2 for s in mrope_section]`. # This results in [16, 24, 24, 16, 24, 24] self.mrope_splits = self.mrope_sections * 2 - + if sum(self.mrope_splits) != self.head_size: # The sum (128) should now correctly match self.head_size (128) - raise ValueError(f"MRoPE splits {self.mrope_splits} sum ({sum(self.mrope_splits)}) does not match head size ({self.head_size})") + raise ValueError( + f"MRoPE splits {self.mrope_splits} sum ({sum(self.mrope_splits)}) does not match head size ({self.head_size})" + ) # Force GroupQueryAttention for fp32 cuda, # as base.py's make_attention_init doesn't include this combo. if self.ep == "cuda" and self.io_dtype == ir.DataType.FLOAT: self.attention_attrs["op_type"] = "GroupQueryAttention" print("Forcing GroupQueryAttention (GQA) for FP32 CUDA.") - + if self.attention_attrs["op_type"] != "GroupQueryAttention": - raise ValueError(f"Qwen2.5-VL requires GroupQueryAttention, but op_type is {self.attention_attrs['op_type']}. This may be due to an unsupported EP/precision combo.") + raise ValueError( + f"Qwen2.5-VL requires GroupQueryAttention, but op_type is {self.attention_attrs['op_type']}. This may be due to an unsupported EP/precision combo." + ) # Create and save the inv_freq tensor self.make_inv_freq_tensor() @@ -142,21 +156,25 @@ def make_inv_freq_tensor(self): This is copied from base.py:make_rotary_embedding_caches_from_scratch """ dim = int(self.rope_attrs["partial_rotary_factor"] * self.head_size) - inv_freq = 1.0 / (self.rope_attrs["rescale_factors"] * (self.rope_attrs["theta"] ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))) - + inv_freq = 1.0 / ( + self.rope_attrs["rescale_factors"] + * (self.rope_attrs["theta"] ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + ) + # The HF model expects H/2, not R/2 if dim != self.head_size: - print(f"Warning: partial_rotary_factor ({self.rope_attrs['partial_rotary_factor']}) is not 1. This might be unsupported.") - inv_freq = inv_freq[:(self.head_size // 2)] - + print( + f"Warning: partial_rotary_factor ({self.rope_attrs['partial_rotary_factor']}) is not 1. This might be unsupported." + ) + inv_freq = inv_freq[: (self.head_size // 2)] + self.make_initializer(inv_freq, "model.inv_freq", to=ir.DataType.FLOAT) print("Created and saved 'model.inv_freq' initializer.") - def make_inputs_and_outputs(self): # Qwen2.5-VL uses 3D position_ids self.input_shapes["position_ids"] = [3, "batch_size", "sequence_length"] - + # Call the base Model's make_inputs_and_outputs (skipping MistralModel's) super().make_inputs_and_outputs() @@ -166,7 +184,7 @@ def make_dynamic_rope_caches(self, layer_id, basename): Takes 3D position_ids and inv_freq and dynamically creates the cos/sin caches. """ - pos_ids_name = "position_ids" + pos_ids_name = "position_ids" inv_freq_name = "model.inv_freq" head_dim_half = self.head_size // 2 @@ -177,85 +195,157 @@ def make_dynamic_rope_caches(self, layer_id, basename): gather_batch_size_name = f"{basename}/gather_batch_size" gather_batch_size_output = f"{gather_batch_size_name}/output_0" - self.make_gather(gather_batch_size_name, [shape_pos_ids_output, "/model/constants/INT64/[1]"], ir.DataType.INT64, [1], axis=0) - + self.make_gather( + gather_batch_size_name, + [shape_pos_ids_output, "/model/constants/INT64/[1]"], + ir.DataType.INT64, + [1], + axis=0, + ) + # Expand inv_freq: [H/2] -> [1, 1, H/2, 1] unsqueeze_1_name = f"{basename}/inv_freq_unsqueeze_1" unsqueeze_1_output = f"{unsqueeze_1_name}/output_0" - self.make_unsqueeze(unsqueeze_1_name, [inv_freq_name, "/model/constants/INT64/[0, 1, 3]"], ir.DataType.FLOAT, [1, 1, head_dim_half, 1]) - + self.make_unsqueeze( + unsqueeze_1_name, + [inv_freq_name, "/model/constants/INT64/[0, 1, 3]"], + ir.DataType.FLOAT, + [1, 1, head_dim_half, 1], + ) + # Create target shape for Expand: [3, B, H/2, 1] concat_expand_shape_name = f"{basename}/concat_expand_shape" concat_expand_shape_output = f"{concat_expand_shape_name}/output_0" self.make_concat( concat_expand_shape_name, - ["/model/constants/INT64/[3]", gather_batch_size_output, f"/model/constants/INT64/[{head_dim_half}, 1]"], + [ + "/model/constants/INT64/[3]", + gather_batch_size_output, + f"/model/constants/INT64/[{head_dim_half}, 1]", + ], ir.DataType.INT64, [4], - axis=0 + axis=0, ) - + expand_name = f"{basename}/inv_freq_expand" expand_output = f"{expand_name}/output_0" - self.make_expand(expand_name, [unsqueeze_1_output, concat_expand_shape_output], ir.DataType.FLOAT, [3, "batch_size", head_dim_half, 1]) - + self.make_expand( + expand_name, + [unsqueeze_1_output, concat_expand_shape_output], + ir.DataType.FLOAT, + [3, "batch_size", head_dim_half, 1], + ) + # Expand position_ids: [3, B, S] -> [3, B, 1, S] unsqueeze_2_name = f"{basename}/pos_ids_unsqueeze" unsqueeze_2_output = f"{unsqueeze_2_name}/output_0" - self.make_unsqueeze(unsqueeze_2_name, [pos_ids_name, "/model/constants/INT64/[2]"], ir.DataType.INT64, [3, "batch_size", 1, "sequence_length"]) - + self.make_unsqueeze( + unsqueeze_2_name, + [pos_ids_name, "/model/constants/INT64/[2]"], + ir.DataType.INT64, + [3, "batch_size", 1, "sequence_length"], + ) + # Cast position_ids to float cast_name = f"{basename}/pos_ids_cast" cast_output = f"{cast_name}/output_0" - self.make_cast(cast_name, unsqueeze_2_output, ir.DataType.FLOAT, [3, "batch_size", 1, "sequence_length"]) + self.make_cast( + cast_name, + unsqueeze_2_output, + ir.DataType.FLOAT, + [3, "batch_size", 1, "sequence_length"], + ) # MatMul: [3, B, H/2, 1] @ [3, B, 1, S] -> [3, B, H/2, S] matmul_name = f"{basename}/freqs_matmul" matmul_output = f"{matmul_name}/output_0" self.make_node("MatMul", [expand_output, cast_output], [matmul_output], name=matmul_name) - self.make_value(matmul_output, ir.DataType.FLOAT, [3, "batch_size", head_dim_half, "sequence_length"]) + self.make_value( + matmul_output, + ir.DataType.FLOAT, + [3, "batch_size", head_dim_half, "sequence_length"], + ) # Transpose: [3, B, H/2, S] -> [3, B, S, H/2] transpose_name = f"{basename}/freqs_transpose" transpose_output = f"{transpose_name}/output_0" - self.make_transpose(transpose_name, matmul_output, ir.DataType.FLOAT, [3, "batch_size", "sequence_length", head_dim_half], perm=[0, 1, 3, 2]) + self.make_transpose( + transpose_name, + matmul_output, + ir.DataType.FLOAT, + [3, "batch_size", "sequence_length", head_dim_half], + perm=[0, 1, 3, 2], + ) # Concat (freqs, freqs): [3, B, S, H/2] -> [3, B, S, H] concat_name = f"{basename}/emb_concat" concat_output = f"{concat_name}/output_0" - self.make_concat(concat_name, [transpose_output, transpose_output], ir.DataType.FLOAT, [3, "batch_size", "sequence_length", self.head_size], axis=-1) + self.make_concat( + concat_name, + [transpose_output, transpose_output], + ir.DataType.FLOAT, + [3, "batch_size", "sequence_length", self.head_size], + axis=-1, + ) # Cos(emb) and Sin(emb) cos_name = f"{basename}/cos" cos_output = f"{cos_name}/output_0" self.make_node("Cos", [concat_output], [cos_output], name=cos_name) - self.make_value(cos_output, ir.DataType.FLOAT, [3, "batch_size", "sequence_length", self.head_size]) - + self.make_value( + cos_output, + ir.DataType.FLOAT, + [3, "batch_size", "sequence_length", self.head_size], + ) + sin_name = f"{basename}/sin" sin_output = f"{sin_name}/output_0" self.make_node("Sin", [concat_output], [sin_output], name=sin_name) - self.make_value(sin_output, ir.DataType.FLOAT, [3, "batch_size", "sequence_length", self.head_size]) + self.make_value( + sin_output, + ir.DataType.FLOAT, + [3, "batch_size", "sequence_length", self.head_size], + ) # Apply attention_scaling cos_final_output = cos_output sin_final_output = sin_output - scale = self.rope_attrs.get("attention_scaling", 1.0) # Get from rope_attrs + scale = self.rope_attrs.get("attention_scaling", 1.0) # Get from rope_attrs if scale != 1.0: scale_const_name = f"/model/constants/FLOAT/{scale}" - + cos_mul_name = f"{basename}/cos_mul_scale" cos_final_output = f"{cos_mul_name}/output_0" - self.make_node("Mul", [cos_output, scale_const_name], [cos_final_output], name=cos_mul_name) - self.make_value(cos_final_output, ir.DataType.FLOAT, [3, "batch_size", "sequence_length", self.head_size]) + self.make_node( + "Mul", + [cos_output, scale_const_name], + [cos_final_output], + name=cos_mul_name, + ) + self.make_value( + cos_final_output, + ir.DataType.FLOAT, + [3, "batch_size", "sequence_length", self.head_size], + ) sin_mul_name = f"{basename}/sin_mul_scale" sin_final_output = f"{sin_mul_name}/output_0" - self.make_node("Mul", [sin_output, scale_const_name], [sin_final_output], name=sin_mul_name) - self.make_value(sin_final_output, ir.DataType.FLOAT, [3, "batch_size", "sequence_length", self.head_size]) + self.make_node( + "Mul", + [sin_output, scale_const_name], + [sin_final_output], + name=sin_mul_name, + ) + self.make_value( + sin_final_output, + ir.DataType.FLOAT, + [3, "batch_size", "sequence_length", self.head_size], + ) return cos_final_output, sin_final_output - + def rotate_half(self, x_name, x_shape, basename, compute_dtype): """ Builds ONNX nodes for rotate_half(x) @@ -265,22 +355,29 @@ def rotate_half(self, x_name, x_shape, basename, compute_dtype): split_name = f"{basename}/rotate_half/Split" split_output_0 = f"{split_name}/output_0" split_output_1 = f"{split_name}/output_1" - self.make_node("Split", [x_name], [split_output_0, split_output_1], name=split_name, axis=-1, num_outputs=2) + self.make_node( + "Split", + [x_name], + [split_output_0, split_output_1], + name=split_name, + axis=-1, + num_outputs=2, + ) half_shape = x_shape[:-1] + [x_shape[-1] // 2] self.make_value(split_output_0, compute_dtype, half_shape) self.make_value(split_output_1, compute_dtype, half_shape) - + # Negate x2 neg_name = f"{basename}/rotate_half/Neg" neg_output = f"{neg_name}/output_0" self.make_node("Neg", [split_output_1], [neg_output], name=neg_name) self.make_value(neg_output, compute_dtype, half_shape) - + # Concat (-x2, x1) concat_name = f"{basename}/rotate_half/Concat" concat_output = f"{concat_name}/output_0" self.make_concat(concat_name, [neg_output, split_output_0], compute_dtype, x_shape, axis=-1) - + return concat_output def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn_sin, num_heads, basename): @@ -289,11 +386,11 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn Takes Q/K tensor and the dynamically generated 3D caches and applies the rotation. """ - + # --- Handle precision for RoPE --- # Check if we need to force float32 computation force_fp32 = self.attention_attrs.get("rope_cast", {}).get("use_fp32", False) - + # Set compute_dtype (precision for math) and output_dtype (final precision) compute_dtype = ir.DataType.FLOAT if force_fp32 else self.io_dtype output_dtype = self.io_dtype @@ -305,91 +402,176 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn mrope_splits_output_name = f"{basename}/mrope_splits" mrope_splits_tensor = ir.tensor( torch.tensor(self.mrope_splits, dtype=torch.int64), - name=mrope_splits_output_name + name=mrope_splits_output_name, ) self.make_node( "Constant", inputs=[], outputs=[mrope_splits_output_name], name=mrope_splits_node_name, - value=mrope_splits_tensor + value=mrope_splits_tensor, ) self.make_value(mrope_splits_output_name, ir.DataType.INT64, [len(self.mrope_splits)]) - + # Split the dynamic caches [3, B, S, H] into 6 chunks on axis -1 # Caches (dyn_cos, dyn_sin) are already in float32 num_splits = len(self.mrope_splits) - + cos_split_name = f"{basename}/cos_split" cos_split_outputs = [f"{cos_split_name}/output_{i}" for i in range(num_splits)] - self.make_node("Split", [dyn_cos, mrope_splits_output_name], cos_split_outputs, name=cos_split_name, axis=-1) + self.make_node( + "Split", + [dyn_cos, mrope_splits_output_name], + cos_split_outputs, + name=cos_split_name, + axis=-1, + ) sin_split_name = f"{basename}/sin_split" sin_split_outputs = [f"{sin_split_name}/output_{i}" for i in range(num_splits)] - self.make_node("Split", [dyn_sin, mrope_splits_output_name], sin_split_outputs, name=sin_split_name, axis=-1) - + self.make_node( + "Split", + [dyn_sin, mrope_splits_output_name], + sin_split_outputs, + name=sin_split_name, + axis=-1, + ) + # Re-order the caches: [T, H, W, T, H, W] cos_reordered = [] sin_reordered = [] for i in range(num_splits): dim_chunk = self.mrope_splits[i] - cache_dim_to_use = i % 3 # 0 for T, 1 for H, 2 for W - + cache_dim_to_use = i % 3 # 0 for T, 1 for H, 2 for W + # Gather from dim 0 of the split cache chunk # input is [3, B, S, H_chunk], indices is [0, 1, or 2] gather_cos_name = f"{basename}/cos_gather_{i}" gather_cos_output = f"{gather_cos_name}/output_0" - self.make_node("Gather", [cos_split_outputs[i], f"/model/constants/INT64/[{cache_dim_to_use}]"], [gather_cos_output], name=gather_cos_name, axis=0) - self.make_value(gather_cos_output, ir.DataType.FLOAT, [1, "batch_size", "sequence_length", dim_chunk]) # Shape [1, B, S, H_chunk] - + self.make_node( + "Gather", + [cos_split_outputs[i], f"/model/constants/INT64/[{cache_dim_to_use}]"], + [gather_cos_output], + name=gather_cos_name, + axis=0, + ) + self.make_value( + gather_cos_output, + ir.DataType.FLOAT, + [1, "batch_size", "sequence_length", dim_chunk], + ) # Shape [1, B, S, H_chunk] + gather_sin_name = f"{basename}/sin_gather_{i}" gather_sin_output = f"{gather_sin_name}/output_0" - self.make_node("Gather", [sin_split_outputs[i], f"/model/constants/INT64/[{cache_dim_to_use}]"], [gather_sin_output], name=gather_sin_name, axis=0) - self.make_value(gather_sin_output, ir.DataType.FLOAT, [1, "batch_size", "sequence_length", dim_chunk]) # Shape [1, B, S, H_chunk] + self.make_node( + "Gather", + [sin_split_outputs[i], f"/model/constants/INT64/[{cache_dim_to_use}]"], + [gather_sin_output], + name=gather_sin_name, + axis=0, + ) + self.make_value( + gather_sin_output, + ir.DataType.FLOAT, + [1, "batch_size", "sequence_length", dim_chunk], + ) # Shape [1, B, S, H_chunk] # FIX: Squeeze the gathered cache to [B, S, H_chunk] squeeze_cos_name = f"{basename}/cos_squeeze_{i}" squeeze_cos_output = f"{squeeze_cos_name}/output_0" - self.make_squeeze(squeeze_cos_name, [gather_cos_output, "/model/constants/INT64/[0]"], ir.DataType.FLOAT, ["batch_size", "sequence_length", dim_chunk]) + self.make_squeeze( + squeeze_cos_name, + [gather_cos_output, "/model/constants/INT64/[0]"], + ir.DataType.FLOAT, + ["batch_size", "sequence_length", dim_chunk], + ) squeeze_sin_name = f"{basename}/sin_squeeze_{i}" squeeze_sin_output = f"{squeeze_sin_name}/output_0" - self.make_squeeze(squeeze_sin_name, [gather_sin_output, "/model/constants/INT64/[0]"], ir.DataType.FLOAT, ["batch_size", "sequence_length", dim_chunk]) - + self.make_squeeze( + squeeze_sin_name, + [gather_sin_output, "/model/constants/INT64/[0]"], + ir.DataType.FLOAT, + ["batch_size", "sequence_length", dim_chunk], + ) + # Unsqueeze to add the NumHeads dim: [B, 1, S, H_chunk] unsqueeze_cos_name = f"{basename}/cos_unsqueeze_{i}" unsqueeze_cos_output = f"{unsqueeze_cos_name}/output_0" - self.make_unsqueeze(unsqueeze_cos_name, [squeeze_cos_output, "/model/constants/INT64/[1]"], ir.DataType.FLOAT, ["batch_size", 1, "sequence_length", dim_chunk]) + self.make_unsqueeze( + unsqueeze_cos_name, + [squeeze_cos_output, "/model/constants/INT64/[1]"], + ir.DataType.FLOAT, + ["batch_size", 1, "sequence_length", dim_chunk], + ) cos_reordered.append(unsqueeze_cos_output) - + unsqueeze_sin_name = f"{basename}/sin_unsqueeze_{i}" unsqueeze_sin_output = f"{unsqueeze_sin_name}/output_0" - self.make_unsqueeze(unsqueeze_sin_name, [squeeze_sin_output, "/model/constants/INT64/[1]"], ir.DataType.FLOAT, ["batch_size", 1, "sequence_length", dim_chunk]) + self.make_unsqueeze( + unsqueeze_sin_name, + [squeeze_sin_output, "/model/constants/INT64/[1]"], + ir.DataType.FLOAT, + ["batch_size", 1, "sequence_length", dim_chunk], + ) sin_reordered.append(unsqueeze_sin_output) # Concat re-ordered chunks back to [B, 1, S, H] final_cos_concat_name = f"{basename}/cos_final_concat" final_cos_concat_output = f"{final_cos_concat_name}/output_0" - self.make_concat(final_cos_concat_name, cos_reordered, ir.DataType.FLOAT, ["batch_size", 1, "sequence_length", self.head_size], axis=-1) + self.make_concat( + final_cos_concat_name, + cos_reordered, + ir.DataType.FLOAT, + ["batch_size", 1, "sequence_length", self.head_size], + axis=-1, + ) final_sin_concat_name = f"{basename}/sin_final_concat" final_sin_concat_output = f"{final_sin_concat_name}/output_0" - self.make_concat(final_sin_concat_name, sin_reordered, ir.DataType.FLOAT, ["batch_size", 1, "sequence_length", self.head_size], axis=-1) + self.make_concat( + final_sin_concat_name, + sin_reordered, + ir.DataType.FLOAT, + ["batch_size", 1, "sequence_length", self.head_size], + axis=-1, + ) # Caches (final_cos_concat_output, final_sin_concat_output) are now in float32 - + # Reshape input Q/K: [B, S, N*H] -> [B, N, S, H] reshape_1_name = f"{basename}/q_or_k_reshape_1" reshape_1_output = f"{reshape_1_name}/output_0" reshape_1_target_shape_onnx = f"/model/constants/INT64/[0, 0, {num_heads}, {self.head_size}]" - reshape_1_target_shape_ort = ["batch_size", "sequence_length", num_heads, self.head_size] - self.make_reshape(reshape_1_name, [q_or_k_path, reshape_1_target_shape_onnx], self.io_dtype, reshape_1_target_shape_ort) + reshape_1_target_shape_ort = [ + "batch_size", + "sequence_length", + num_heads, + self.head_size, + ] + self.make_reshape( + reshape_1_name, + [q_or_k_path, reshape_1_target_shape_onnx], + self.io_dtype, + reshape_1_target_shape_ort, + ) # Transpose Q/K: [B, S, N, H] -> [B, N, S, H] transpose_1_name = f"{basename}/q_or_k_transpose_1" transpose_1_output = f"{transpose_1_name}/output_0" - transpose_1_target_shape = ["batch_size", num_heads, "sequence_length", self.head_size] - self.make_transpose(transpose_1_name, reshape_1_output, self.io_dtype, transpose_1_target_shape, perm=[0, 2, 1, 3]) + transpose_1_target_shape = [ + "batch_size", + num_heads, + "sequence_length", + self.head_size, + ] + self.make_transpose( + transpose_1_name, + reshape_1_output, + self.io_dtype, + transpose_1_target_shape, + perm=[0, 2, 1, 3], + ) # --- Start RoPE computation --- q_or_k_compute_input = transpose_1_output @@ -400,40 +582,70 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn # Cast Q/K (self.io_dtype) up to float32 q_or_k_cast_name = f"{basename}/q_or_k_cast_fp32" q_or_k_cast_output = f"{q_or_k_cast_name}/output_0" - self.make_cast(q_or_k_cast_name, transpose_1_output, compute_dtype, transpose_1_target_shape) + self.make_cast( + q_or_k_cast_name, + transpose_1_output, + compute_dtype, + transpose_1_target_shape, + ) q_or_k_compute_input = q_or_k_cast_output elif not force_fp32 and self.io_dtype != ir.DataType.FLOAT: # Cast Caches (float32) down to self.io_dtype cos_cache_cast_name = f"{basename}/cos_final_cast" cos_cache_cast_output = f"{cos_cache_cast_name}/output_0" - self.make_cast(cos_cache_cast_name, final_cos_concat_output, compute_dtype, ["batch_size", 1, "sequence_length", self.head_size]) + self.make_cast( + cos_cache_cast_name, + final_cos_concat_output, + compute_dtype, + ["batch_size", 1, "sequence_length", self.head_size], + ) cos_cache_compute_input = cos_cache_cast_output sin_cache_cast_name = f"{basename}/sin_final_cast" sin_cache_cast_output = f"{sin_cache_cast_name}/output_0" - self.make_cast(sin_cache_cast_name, final_sin_concat_output, compute_dtype, ["batch_size", 1, "sequence_length", self.head_size]) + self.make_cast( + sin_cache_cast_name, + final_sin_concat_output, + compute_dtype, + ["batch_size", 1, "sequence_length", self.head_size], + ) sin_cache_compute_input = sin_cache_cast_output # Apply rotation: (q * cos) + (rotate_half(q) * sin) - + # 1. (q * cos) mul_1_name = f"{basename}/mul_1" mul_1_output = f"{mul_1_name}/output_0" - self.make_mul(mul_1_name, [q_or_k_compute_input, cos_cache_compute_input], compute_dtype, transpose_1_target_shape) - + self.make_mul( + mul_1_name, + [q_or_k_compute_input, cos_cache_compute_input], + compute_dtype, + transpose_1_target_shape, + ) + # 2. rotate_half(q) rotated_half_q_name = self.rotate_half(q_or_k_compute_input, transpose_1_target_shape, basename, compute_dtype) - + # 3. (rotate_half(q) * sin) mul_2_name = f"{basename}/mul_2" mul_2_output = f"{mul_2_name}/output_0" - self.make_mul(mul_2_name, [rotated_half_q_name, sin_cache_compute_input], compute_dtype, transpose_1_target_shape) + self.make_mul( + mul_2_name, + [rotated_half_q_name, sin_cache_compute_input], + compute_dtype, + transpose_1_target_shape, + ) # 4. (q * cos) + (rotate_half(q) * sin) add_name = f"{basename}/add" add_output = f"{add_name}/output_0" - self.make_add(add_name, [mul_1_output, mul_2_output], compute_dtype, transpose_1_target_shape) - + self.make_add( + add_name, + [mul_1_output, mul_2_output], + compute_dtype, + transpose_1_target_shape, + ) + # --- End RoPE computation --- add_output_final = add_output @@ -447,31 +659,48 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn # Transpose back: [B, N, S, H] -> [B, S, N, H] transpose_2_name = f"{basename}/q_or_k_transpose_2" transpose_2_output = f"{transpose_2_name}/output_0" - self.make_transpose(transpose_2_name, add_output_final, output_dtype, reshape_1_target_shape_ort, perm=[0, 2, 1, 3]) - + self.make_transpose( + transpose_2_name, + add_output_final, + output_dtype, + reshape_1_target_shape_ort, + perm=[0, 2, 1, 3], + ) + # Reshape back: [B, S, N, H] -> [B, S, N*H] reshape_2_name = f"{basename}/q_or_k_reshape_2" reshape_2_output = f"{reshape_2_name}/output_0" - self.make_reshape(reshape_2_name, [transpose_2_output, f"/model/constants/INT64/[0, 0, {num_heads * self.head_size}]"], output_dtype, q_or_k_shape) - + self.make_reshape( + reshape_2_name, + [ + transpose_2_output, + f"/model/constants/INT64/[0, 0, {num_heads * self.head_size}]", + ], + output_dtype, + q_or_k_shape, + ) + return reshape_2_output def make_attention(self, layer_id, attention, root_input, **kwargs): - # 1. Unpack QKV if necessary (e.g. qkv_proj) super().make_attention_unpacked(layer_id, attention, root_input, **kwargs) - + # 2. Build Q/K/V MatMul and Add nodes q_matmul_basename = f"/model/layers.{layer_id}/attn/q_proj/MatMul" q_matmul_name = self.make_matmul(attention.q_proj, q_matmul_basename, root_input) self.attention_attrs["q_path"] = f"{q_matmul_name}/output_0" - q_shape = ["batch_size", "sequence_length", self.num_attn_heads * self.head_size] - + q_shape = [ + "batch_size", + "sequence_length", + self.num_attn_heads * self.head_size, + ] + k_matmul_basename = f"/model/layers.{layer_id}/attn/k_proj/MatMul" k_matmul_name = self.make_matmul(attention.k_proj, k_matmul_basename, root_input) self.attention_attrs["k_path"] = f"{k_matmul_name}/output_0" k_shape = ["batch_size", "sequence_length", self.num_kv_heads * self.head_size] - + v_matmul_basename = f"/model/layers.{layer_id}/attn/v_proj/MatMul" v_matmul_name = self.make_matmul(attention.v_proj, v_matmul_basename, root_input) self.attention_attrs["v_path"] = f"{v_matmul_name}/output_0" @@ -483,20 +712,34 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): if q_bias_exists: q_add_name = f"/model/layers.{layer_id}/attn/q_proj/Add" - self.make_add_bias(attention.q_proj.bias, q_add_name, root_input=self.attention_attrs["q_path"]) + self.make_add_bias( + attention.q_proj.bias, + q_add_name, + root_input=self.attention_attrs["q_path"], + ) self.attention_attrs["q_path"] = f"{q_add_name}/output_0" if k_bias_exists: k_add_name = f"/model/layers.{layer_id}/attn/k_proj/Add" - self.make_add_bias(attention.k_proj.bias, k_add_name, root_input=self.attention_attrs["k_path"]) + self.make_add_bias( + attention.k_proj.bias, + k_add_name, + root_input=self.attention_attrs["k_path"], + ) self.attention_attrs["k_path"] = f"{k_add_name}/output_0" if v_bias_exists: v_add_name = f"/model/layers.{layer_id}/attn/v_proj/Add" - self.make_add_bias(attention.v_proj.bias, v_add_name, root_input=self.attention_attrs["v_path"]) + self.make_add_bias( + attention.v_proj.bias, + v_add_name, + root_input=self.attention_attrs["v_path"], + ) self.attention_attrs["v_path"] = f"{v_add_name}/output_0" # 3. Apply 3D RoPE (MRoPE) - cos_dynamic, sin_dynamic = self.make_dynamic_rope_caches(layer_id, basename=f"/model/layers.{layer_id}/attn/mrope_dynamic_cache") - + cos_dynamic, sin_dynamic = self.make_dynamic_rope_caches( + layer_id, basename=f"/model/layers.{layer_id}/attn/mrope_dynamic_cache" + ) + # Apply rotation to Q self.attention_attrs["q_path"] = self.apply_mrope_rotation( layer_id, @@ -505,9 +748,9 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): cos_dynamic, sin_dynamic, self.num_attn_heads, - basename=f"/model/layers.{layer_id}/attn/q_mrope" + basename=f"/model/layers.{layer_id}/attn/q_mrope", ) - + # Apply rotation to K self.attention_attrs["k_path"] = self.apply_mrope_rotation( layer_id, @@ -516,7 +759,7 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): cos_dynamic, sin_dynamic, self.num_kv_heads, - basename=f"/model/layers.{layer_id}/attn/k_mrope" + basename=f"/model/layers.{layer_id}/attn/k_mrope", ) # 4. Call GroupQueryAttention op @@ -527,22 +770,22 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): attn_name = f"/model/layers.{layer_id}/attn/{self.attention_attrs['op_type']}" self.make_attention_op( - attn_name, - q_path=self.attention_attrs["q_path"], - k_path=self.attention_attrs["k_path"], + attn_name, + q_path=self.attention_attrs["q_path"], + k_path=self.attention_attrs["k_path"], v_path=self.attention_attrs["v_path"], - past_k=past_k, - past_v=past_v, - present_k=present_k, + past_k=past_k, + past_v=past_v, + present_k=present_k, present_v=present_v, # Pass empty strings for fused caches since we applied RoPE manually - cos_cache="", - sin_cache="", + cos_cache="", + sin_cache="", **kwargs, ) # 5. Build O-proj - o_proj = 'o_proj' if hasattr(attention, 'o_proj') else 'dense' + o_proj = "o_proj" if hasattr(attention, "o_proj") else "dense" o_matmul_basename = f"/model/layers.{layer_id}/attn/o_proj/MatMul" o_weight = getattr(attention, o_proj) o_matmul_name = self.make_matmul(o_weight, o_matmul_basename, f"{attn_name}/output_0") @@ -556,35 +799,35 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): else: self.layernorm_attrs["skip_input"] = f"{o_matmul_name}/output_0" - def make_model(self, input_path, config=None): - + def make_model(self, input_path): # Make inputs and outputs to ONNX model self.make_inputs_and_outputs() # Make pre-processing nodes self.make_preprocessing_nodes() - + # Load the Hugging Face model from transformers import Qwen2_5_VLForConditionalGeneration + print("Loading Qwen2_5_VLForConditionalGeneration model...") hf_model = Qwen2_5_VLForConditionalGeneration.from_pretrained( - self.model_name_or_path, - config=config, - cache_dir=self.cache_dir, - token=self.hf_token, - trust_remote_code=self.hf_remote + self.model_name_or_path, + config=self.config, + cache_dir=self.cache_dir, + token=self.hf_token, + trust_remote_code=self.hf_remote, ) - + # We only want to export the text model model = hf_model.language_model print(f"Isolated language_model ({model.__class__.__name__}) for ONNX export.") # Loop through model and map each module to ONNX/ORT ops self.layer_id = 0 - + # The base.Model.make_model() loop expects modules from a standard causal LM, # so we replicate its logic here but point to the correct modules in the hf_model - + # Handle Embeddings if not self.exclude_embeds: print("Reading embedding layer") @@ -595,19 +838,25 @@ def make_model(self, input_path, config=None): print("Skipping embedding layer, model will expect 'inputs_embeds'.") self.layernorm_attrs["root_input"] = "inputs_embeds" self.layernorm_attrs["skip_input"] = "inputs_embeds" - + # Handle Decoder Layers for layer in model.layers: if self.layer_id < self.num_layers: print(f"Reading decoder layer {self.layer_id}") self.make_layer(self.layer_id, layer) self.layer_id += 1 - + # Handle Final Norm if self.layer_id == self.num_layers and hasattr(model, "norm"): print("Reading final norm") - self.make_layernorm(self.layer_id, model.norm, skip=True, simple=self.layernorm_attrs["simple"], location="final_norm") - + self.make_layernorm( + self.layer_id, + model.norm, + skip=True, + simple=self.layernorm_attrs["simple"], + location="final_norm", + ) + # Handle LM Head if not self.exclude_lm_head: # The LM head is part of the parent Qwen2_5_VLForConditionalGeneration model diff --git a/src/python/py/models/test_vl.py b/src/python/py/models/test_vl.py new file mode 100644 index 0000000000..7ef151e2de --- /dev/null +++ b/src/python/py/models/test_vl.py @@ -0,0 +1,61 @@ +from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor +from qwen_vl_utils import process_vision_info + +# default: Load the model on the available device(s) +model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype="auto", device_map="auto", cache_dir="/home/tlwu/git/onnxruntime-genai/src/python/py/models/qwen2.5_vl_7b_instruct" +) + +# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios. +# model = Qwen2_5_VLForConditionalGeneration.from_pretrained( +# "Qwen/Qwen2.5-VL-7B-Instruct", +# torch_dtype=torch.bfloat16, +# attn_implementation="flash_attention_2", +# device_map="auto", +# ) + +# default processer +processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + +# The default range for the number of visual tokens per image in the model is 4-16384. +# You can set min_pixels and max_pixels according to your needs, such as a token range of 256-1280, to balance performance and cost. +# min_pixels = 256*28*28 +# max_pixels = 1280*28*28 +# processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels) + +messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + }, + {"type": "text", "text": "Describe this image."}, + ], + } +] + +# Preparation for inference +text = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True +) +image_inputs, video_inputs = process_vision_info(messages) +inputs = processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", +) +inputs = inputs.to("cuda") + +# Inference: Generation of the output +generated_ids = model.generate(**inputs, max_new_tokens=128) +generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) +] +output_text = processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False +) +print(output_text) diff --git a/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py b/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py index c0ab3c6d01..2c05b4acc0 100644 --- a/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py +++ b/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py @@ -3,14 +3,15 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -import os import argparse -import torch +import os + import numpy as np import onnxruntime as ort +import torch from onnx import TensorProto from transformers import Qwen2_5_VLForConditionalGeneration -from typing import Tuple, Dict, List + def torch_dtype_to_onnx_tensor_proto(dtype: torch.dtype) -> int: """Maps torch.dtype to onnx.TensorProto.DataType""" @@ -28,6 +29,7 @@ def torch_dtype_to_onnx_tensor_proto(dtype: torch.dtype) -> int: return TensorProto.BOOL raise ValueError(f"Unsupported torch dtype: {dtype}") + def to_numpy(tensor): """Move tensor to CPU and convert to numpy, handling bf16.""" if tensor.dtype == torch.bfloat16: @@ -35,62 +37,55 @@ def to_numpy(tensor): return tensor.detach().cpu().to(torch.float32).numpy() return tensor.detach().cpu().numpy() + def compare_outputs( hf_logits: torch.Tensor, - ort_logits: torch.Tensor, # Changed to torch.Tensor - hf_presents: List[Tuple[torch.Tensor, torch.Tensor]], - ort_presents: List[torch.Tensor], # Changed to list[torch.Tensor] + ort_logits: torch.Tensor, # Changed to torch.Tensor + hf_presents: list[tuple[torch.Tensor, torch.Tensor]], + ort_presents: list[torch.Tensor], # Changed to list[torch.Tensor] step_name: str, rtol: float, - atol: float + atol: float, ): """Compares logits and KV cache outputs using numpy.""" - + print(f"--- Comparing {step_name} Logits ---") - + # We can use to_numpy safely here because we'll compare fp32 vs fp32 # or (bf16->fp32) vs (bf16->fp32) - np.testing.assert_allclose( - to_numpy(hf_logits), - to_numpy(ort_logits), - rtol=rtol, - atol=atol - ) + np.testing.assert_allclose(to_numpy(hf_logits), to_numpy(ort_logits), rtol=rtol, atol=atol) print("Logits: PASS") print(f"\n--- Comparing {step_name} KV Cache ---") # hf_presents is now a list of tuples: [(k0, v0), (k1, v1), ...] # Flatten it to a list: [k0, v0, k1, v1, ...] hf_presents_list = [t for layer_kv in hf_presents for t in layer_kv] - - assert len(hf_presents_list) == len(ort_presents), \ + + assert len(hf_presents_list) == len(ort_presents), ( f"HF presents count ({len(hf_presents_list)}) != ORT presents count ({len(ort_presents)})" + ) - for i in range(len(hf_presents_list)): + for i in range(len(hf_presents_list)): hf_tensor = hf_presents_list[i] ort_tensor = ort_presents[i] - - np.testing.assert_allclose( - to_numpy(hf_tensor), - to_numpy(ort_tensor), - rtol=rtol, - atol=atol - ) + + np.testing.assert_allclose(to_numpy(hf_tensor), to_numpy(ort_tensor), rtol=rtol, atol=atol) print(f"KV Cache (all {len(hf_presents_list)} tensors): PASS") print(f"\nāœ… {step_name} Parity Test Passed!\n") + def ort_io_binding_helper( sess: ort.InferenceSession, - input_tensors: Dict[str, torch.Tensor], - output_tensors: Dict[str, torch.Tensor], - device: str + input_tensors: dict[str, torch.Tensor], + output_tensors: dict[str, torch.Tensor], + device: str, ) -> None: """ Binds torch tensors to an ONNX Runtime IOBinding object and runs the session. Tensors must be on the correct device (e.g., 'cuda:0'). """ bind = sess.io_binding() - + # Get device type and index for ORT ort_device = device.split(":")[0] ort_device_id = 0 @@ -101,50 +96,57 @@ def ort_io_binding_helper( if not tensor.is_contiguous(): print(f"Warning: Input tensor {name} is not contiguous. Making it contiguous.") tensor = tensor.contiguous() - input_tensors[name] = tensor # Update dict entry for future runs (decode) - + input_tensors[name] = tensor # Update dict entry for future runs (decode) + bind.bind_input( name, ort_device, ort_device_id, torch_dtype_to_onnx_tensor_proto(tensor.dtype), tensor.shape, - tensor.data_ptr() + tensor.data_ptr(), ) - + for name, tensor in output_tensors.items(): if not tensor.is_contiguous(): print(f"Warning: Output tensor {name} is not contiguous. Making it contiguous.") tensor = tensor.contiguous() - output_tensors[name] = tensor # Update dict entry - + output_tensors[name] = tensor # Update dict entry + bind.bind_output( name, ort_device, ort_device_id, torch_dtype_to_onnx_tensor_proto(tensor.dtype), tensor.shape, - tensor.data_ptr() + tensor.data_ptr(), ) - + sess.run_with_iobinding(bind) -def test_parity(hf_model_name: str, cache_dir: str, onnx_model_path: str, use_gpu: bool, use_bf16: bool, use_fp16: bool): +def test_parity( + hf_model_name: str, + cache_dir: str, + onnx_model_path: str, + use_gpu: bool, + use_bf16: bool, + use_fp16: bool, +): """ Runs a two-step (prefill and decode) parity test between the Hugging Face and ONNX models. """ - + print(f"Loading Hugging Face model: {hf_model_name}") print("This requires `trust_remote_code=True`.") - + if not use_gpu: print("ERROR: This test script now requires a GPU (`--cpu` is not supported) due to IOBinding.") return - device = "cuda:0" # IOBinding needs the specific device ID - + device = "cuda:0" # IOBinding needs the specific device ID + if use_bf16: torch_dtype = torch.bfloat16 # Standard BF16 tolerances @@ -157,30 +159,34 @@ def test_parity(hf_model_name: str, cache_dir: str, onnx_model_path: str, use_gp torch_dtype = torch.float32 # Standard FP32 tolerances rtol, atol = 1e-1, 1e-1 - + allow_bf16_logits = os.getenv("allow_bf16_logits") in ["1", "true", "True"] - + if allow_bf16_logits: logits_dtype = torch_dtype else: # The builder script (base.Model) upcasts logits to float32 - # ONLY when the io_dtype is bfloat16. + # ONLY when the io_dtype is bfloat16. # For FP16 or FP32, it keeps the original dtype. logits_dtype = torch.float32 if use_bf16 else torch_dtype print(f"Allocating ONNX logits output buffer with dtype: {logits_dtype}") - - hf_full_model = Qwen2_5_VLForConditionalGeneration.from_pretrained( - hf_model_name, - torch_dtype=torch_dtype, - trust_remote_code=True, - cache_dir=cache_dir - ).to(device).eval() - + + hf_full_model = ( + Qwen2_5_VLForConditionalGeneration.from_pretrained( + hf_model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + cache_dir=cache_dir, + ) + .to(device) + .eval() + ) + # The ONNX model is *only* the language_model component hf_text_model = hf_full_model.language_model config = hf_text_model.config - + # Get model parameters BATCH_SIZE = 1 PREFILL_LEN = 10 @@ -189,8 +195,8 @@ def test_parity(hf_model_name: str, cache_dir: str, onnx_model_path: str, use_gp NUM_LAYERS = config.num_hidden_layers NUM_KV_HEADS = config.num_key_value_heads HEAD_DIM = config.hidden_size // config.num_attention_heads - VOCAB_SIZE = config.vocab_size # Get vocab size for output - + VOCAB_SIZE = config.vocab_size # Get vocab size for output + print("\n--- Model Parameters ---") print(f"Device: {device}") print(f"DType: {torch_dtype}") @@ -204,41 +210,35 @@ def test_parity(hf_model_name: str, cache_dir: str, onnx_model_path: str, use_gp print(f"Loading ONNX model: {onnx_model_path}") providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] sess = ort.InferenceSession(onnx_model_path, providers=providers) - + # ================================================================= # 1. PREFILL STEP # ================================================================= print(f"Running Prefill Step (Sequence Length = {PREFILL_LEN})...") - + # --- Create HF/Torch Inputs --- # Use randn (normal distribution) scaled down for better stability in FP16 # inputs_embeds are normally centered around 0, unlike rand which is [0, 1] - inputs_embeds_prefill = torch.randn( - (BATCH_SIZE, PREFILL_LEN, HIDDEN_SIZE), - dtype=torch_dtype, - device=device - ) * 0.001 - - # Qwen2.5-VL uses 3D position IDs (temporal, height, width). + inputs_embeds_prefill = ( + torch.randn((BATCH_SIZE, PREFILL_LEN, HIDDEN_SIZE), dtype=torch_dtype, device=device) * 0.001 + ) + + # Qwen2.5-VL uses 3D position IDs (temporal, height, width). # For text tokens, all three dimensions typically use the same sequence index. pos_ids_1d_prefill = torch.arange(PREFILL_LEN, device=device).expand(BATCH_SIZE, -1) position_ids_prefill = pos_ids_1d_prefill.unsqueeze(0).expand(3, -1, -1) - - attention_mask_prefill = torch.ones( - (BATCH_SIZE, PREFILL_LEN), - dtype=torch.int64, - device=device - ) - + + attention_mask_prefill = torch.ones((BATCH_SIZE, PREFILL_LEN), dtype=torch.int64, device=device) + cache_position_prefill = torch.arange(PREFILL_LEN, device=device) - + # --- Create ONNX Input Tensors (on device) --- ort_inputs_prefill = { "inputs_embeds": inputs_embeds_prefill, "position_ids": position_ids_prefill, - "attention_mask": attention_mask_prefill + "attention_mask": attention_mask_prefill, } - + # Create dummy pasts with 0 sequence length past_shape = (BATCH_SIZE, NUM_KV_HEADS, 0, HEAD_DIM) dummy_past = torch.empty(past_shape, dtype=torch_dtype, device=device) @@ -247,15 +247,11 @@ def test_parity(hf_model_name: str, cache_dir: str, onnx_model_path: str, use_gp ort_inputs_prefill[f"past_key_values.{i}.value"] = dummy_past # --- Create ONNX Output Tensors (on device) --- - ort_logits_prefill = torch.empty( - (BATCH_SIZE, PREFILL_LEN, VOCAB_SIZE), - dtype=logits_dtype, - device=device - ) + ort_logits_prefill = torch.empty((BATCH_SIZE, PREFILL_LEN, VOCAB_SIZE), dtype=logits_dtype, device=device) ort_presents_prefill = [] ort_outputs_prefill = {"logits": ort_logits_prefill} present_shape = (BATCH_SIZE, NUM_KV_HEADS, PREFILL_LEN, HEAD_DIM) - + for i in range(NUM_LAYERS): ort_present_k = torch.empty(present_shape, dtype=torch_dtype, device=device) ort_present_v = torch.empty(present_shape, dtype=torch_dtype, device=device) @@ -272,7 +268,7 @@ def test_parity(hf_model_name: str, cache_dir: str, onnx_model_path: str, use_gp past_key_values=None, cache_position=cache_position_prefill, return_dict=True, - use_cache=True + use_cache=True, ) # --- Run ONNX Model with IOBinding --- @@ -284,69 +280,58 @@ def test_parity(hf_model_name: str, cache_dir: str, onnx_model_path: str, use_gp compare_outputs( hf_logits_prefill, - ort_logits_prefill, # This is the tensor we pre-allocated + ort_logits_prefill, # This is the tensor we pre-allocated hf_presents_prefill, - ort_presents_prefill, # This is the list of tensors we pre-allocated + ort_presents_prefill, # This is the list of tensors we pre-allocated step_name="Prefill", rtol=rtol, - atol=atol + atol=atol, ) # ================================================================= # 2. DECODE STEP # ================================================================= print(f"Running Decode Step (Sequence Length = {DECODE_LEN})...") - + # --- Create HF/Torch Inputs --- # Use randn (normal distribution) scaled down - inputs_embeds_decode = torch.randn( - (BATCH_SIZE, DECODE_LEN, HIDDEN_SIZE), - dtype=torch_dtype, - device=device - ) * 0.001 - + inputs_embeds_decode = torch.randn((BATCH_SIZE, DECODE_LEN, HIDDEN_SIZE), dtype=torch_dtype, device=device) * 0.001 + # Position IDs continue from prefill length - pos_ids_1d_decode = torch.tensor( - [[PREFILL_LEN]], - dtype=torch.int64, - device=device - ) + pos_ids_1d_decode = torch.tensor([[PREFILL_LEN]], dtype=torch.int64, device=device) position_ids_decode = pos_ids_1d_decode.unsqueeze(0).expand(3, -1, -1) - - attention_mask_decode = torch.ones( - (BATCH_SIZE, PREFILL_LEN + DECODE_LEN), - dtype=torch.int64, - device=device - ) - + + attention_mask_decode = torch.ones((BATCH_SIZE, PREFILL_LEN + DECODE_LEN), dtype=torch.int64, device=device) + cache_position_decode = torch.tensor([PREFILL_LEN], device=device) - + # Use the KV cache from the HF prefill run hf_past_key_values = hf_outputs_prefill.past_key_values - + # --- Create ONNX Input Tensors (on device) --- ort_inputs_decode = { "inputs_embeds": inputs_embeds_decode, "position_ids": position_ids_decode, - "attention_mask": attention_mask_decode + "attention_mask": attention_mask_decode, } - + # Use the KV cache from the ONNX prefill run (these are already torch tensors) for i in range(NUM_LAYERS): - ort_inputs_decode[f"past_key_values.{i}.key"] = ort_presents_prefill[i*2] - ort_inputs_decode[f"past_key_values.{i}.value"] = ort_presents_prefill[i*2 + 1] - + ort_inputs_decode[f"past_key_values.{i}.key"] = ort_presents_prefill[i * 2] + ort_inputs_decode[f"past_key_values.{i}.value"] = ort_presents_prefill[i * 2 + 1] + # --- Create ONNX Output Tensors (on device) --- # --- FIX: Logits from bf16 ONNX model are intentionally float32 for accuracy --- - ort_logits_decode = torch.empty( - (BATCH_SIZE, DECODE_LEN, VOCAB_SIZE), - dtype=logits_dtype, - device=device - ) + ort_logits_decode = torch.empty((BATCH_SIZE, DECODE_LEN, VOCAB_SIZE), dtype=logits_dtype, device=device) ort_presents_decode = [] ort_outputs_decode = {"logits": ort_logits_decode} - present_shape_decode = (BATCH_SIZE, NUM_KV_HEADS, PREFILL_LEN + DECODE_LEN, HEAD_DIM) - + present_shape_decode = ( + BATCH_SIZE, + NUM_KV_HEADS, + PREFILL_LEN + DECODE_LEN, + HEAD_DIM, + ) + for i in range(NUM_LAYERS): ort_present_k = torch.empty(present_shape_decode, dtype=torch_dtype, device=device) ort_present_v = torch.empty(present_shape_decode, dtype=torch_dtype, device=device) @@ -363,7 +348,7 @@ def test_parity(hf_model_name: str, cache_dir: str, onnx_model_path: str, use_gp past_key_values=hf_past_key_values, cache_position=cache_position_decode, return_dict=True, - use_cache=True + use_cache=True, ) # --- Run ONNX Model with IOBinding --- @@ -372,7 +357,7 @@ def test_parity(hf_model_name: str, cache_dir: str, onnx_model_path: str, use_gp # --- Compare Decode --- hf_logits_decode = hf_full_model.lm_head(hf_outputs_decode.last_hidden_state) hf_presents_decode = hf_outputs_decode.past_key_values - + compare_outputs( hf_logits_decode, ort_logits_decode, @@ -380,12 +365,13 @@ def test_parity(hf_model_name: str, cache_dir: str, onnx_model_path: str, use_gp ort_presents_decode, step_name="Decode", rtol=rtol, - atol=atol + atol=atol, ) - print("="*30) + print("=" * 30) print("šŸŽ‰ All Parity Tests Passed! šŸŽ‰") - print("="*30) + print("=" * 30) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Parity test for Qwen 2.5 VL ONNX model.") @@ -393,41 +379,33 @@ def test_parity(hf_model_name: str, cache_dir: str, onnx_model_path: str, use_gp "--hf_model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct", - help="Path or name of the Hugging Face model." + help="Path or name of the Hugging Face model.", ) parser.add_argument( "--onnx_model", type=str, required=True, - help="Path to the exported ONNX model file." + help="Path to the exported ONNX model file.", ) parser.add_argument( "--cache_dir", type=str, default="./qwen2.5_vl_7b_instruct", - help="Path to the cache directory." + help="Path to the cache directory.", ) - + parser.add_argument( "--cpu", action="store_true", - help="Force running the test on CPU (Not supported with IOBinding)." + help="Force running the test on CPU (Not supported with IOBinding).", ) - parser.add_argument( - "--bf16", - action="store_true", - help="Use bf16 precision." - ) - - parser.add_argument( - "--fp16", - action="store_true", - help="Use fp16 precision." - ) + parser.add_argument("--bf16", action="store_true", help="Use bf16 precision.") + + parser.add_argument("--fp16", action="store_true", help="Use fp16 precision.") args = parser.parse_args() - + if args.cpu and (args.bf16 or args.fp16): print("Warning: Cannot run bf16/fp16 on CPU. Forcing float32.") args.bf16 = False @@ -436,12 +414,12 @@ def test_parity(hf_model_name: str, cache_dir: str, onnx_model_path: str, use_gp if args.cpu: print("Warning: CPU testing with IOBinding is not set up. Forcing GPU.") # This script is now GPU-only - + test_parity( - hf_model_name=args.hf_model, + hf_model_name=args.hf_model, cache_dir=args.cache_dir, - onnx_model_path=args.onnx_model, - use_gpu=True, # Forcing GPU + onnx_model_path=args.onnx_model, + use_gpu=True, # Forcing GPU use_bf16=args.bf16, - use_fp16=args.fp16 - ) \ No newline at end of file + use_fp16=args.fp16, + ) From f3a34b624ca3fba12b2e30de4577322740f01b00 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 19 Nov 2025 19:25:34 +0000 Subject: [PATCH 06/18] clean up --- src/python/py/models/test_vl.py | 61 ------------------- .../models/qwen_2.5_vl/test_qwen_2.5_vl.py | 8 +-- 2 files changed, 3 insertions(+), 66 deletions(-) delete mode 100644 src/python/py/models/test_vl.py diff --git a/src/python/py/models/test_vl.py b/src/python/py/models/test_vl.py deleted file mode 100644 index 7ef151e2de..0000000000 --- a/src/python/py/models/test_vl.py +++ /dev/null @@ -1,61 +0,0 @@ -from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor -from qwen_vl_utils import process_vision_info - -# default: Load the model on the available device(s) -model = Qwen2_5_VLForConditionalGeneration.from_pretrained( - "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype="auto", device_map="auto", cache_dir="/home/tlwu/git/onnxruntime-genai/src/python/py/models/qwen2.5_vl_7b_instruct" -) - -# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios. -# model = Qwen2_5_VLForConditionalGeneration.from_pretrained( -# "Qwen/Qwen2.5-VL-7B-Instruct", -# torch_dtype=torch.bfloat16, -# attn_implementation="flash_attention_2", -# device_map="auto", -# ) - -# default processer -processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") - -# The default range for the number of visual tokens per image in the model is 4-16384. -# You can set min_pixels and max_pixels according to your needs, such as a token range of 256-1280, to balance performance and cost. -# min_pixels = 256*28*28 -# max_pixels = 1280*28*28 -# processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels) - -messages = [ - { - "role": "user", - "content": [ - { - "type": "image", - "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", - }, - {"type": "text", "text": "Describe this image."}, - ], - } -] - -# Preparation for inference -text = processor.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True -) -image_inputs, video_inputs = process_vision_info(messages) -inputs = processor( - text=[text], - images=image_inputs, - videos=video_inputs, - padding=True, - return_tensors="pt", -) -inputs = inputs.to("cuda") - -# Inference: Generation of the output -generated_ids = model.generate(**inputs, max_new_tokens=128) -generated_ids_trimmed = [ - out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) -] -output_text = processor.batch_decode( - generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False -) -print(output_text) diff --git a/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py b/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py index 2c05b4acc0..a1d9e98622 100644 --- a/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py +++ b/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py @@ -95,8 +95,7 @@ def ort_io_binding_helper( for name, tensor in input_tensors.items(): if not tensor.is_contiguous(): print(f"Warning: Input tensor {name} is not contiguous. Making it contiguous.") - tensor = tensor.contiguous() - input_tensors[name] = tensor # Update dict entry for future runs (decode) + input_tensors[name] = tensor.contiguous() bind.bind_input( name, @@ -110,8 +109,7 @@ def ort_io_binding_helper( for name, tensor in output_tensors.items(): if not tensor.is_contiguous(): print(f"Warning: Output tensor {name} is not contiguous. Making it contiguous.") - tensor = tensor.contiguous() - output_tensors[name] = tensor # Update dict entry + output_tensors[name] = tensor.contiguous() bind.bind_output( name, @@ -160,7 +158,7 @@ def test_parity( # Standard FP32 tolerances rtol, atol = 1e-1, 1e-1 - allow_bf16_logits = os.getenv("allow_bf16_logits") in ["1", "true", "True"] + allow_bf16_logits = os.getenv("ALLOW_BF16_LOGITS") in ["1", "true", "True"] if allow_bf16_logits: logits_dtype = torch_dtype From 0d3a85e0d2ee9c1e19647ed99edcbc5d320f56b6 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 19 Nov 2025 19:37:18 +0000 Subject: [PATCH 07/18] format --- src/python/py/models/builders/qwen.py | 7 +- .../models/qwen_2.5_vl/test_qwen_2.5_vl.py | 68 +++++++++---------- 2 files changed, 37 insertions(+), 38 deletions(-) diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index 12eff67679..107b851df7 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -8,6 +8,7 @@ import onnx_ir as ir import torch +from transformers import Qwen2_5_VLForConditionalGeneration from .base import Model @@ -82,7 +83,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): # SOLUTION: We must override the base model's decision and set the # output logits type to match the io_dtype (bfloat16). # - self.allow_bf16_logits = os.getenv("allow_bf16_logits") in ["1", "true", "True"] + self.allow_bf16_logits = os.getenv("ALLOW_BF16_LOGITS") in ["1", "true", "True"] if self.allow_bf16_logits and self.io_dtype == ir.DataType.BFLOAT16: print("Fixing output logits precision. Setting output_types['logits'] to BFLOAT16 to match HF model.") self.output_types["logits"] = ir.DataType.BFLOAT16 @@ -363,7 +364,7 @@ def rotate_half(self, x_name, x_shape, basename, compute_dtype): axis=-1, num_outputs=2, ) - half_shape = x_shape[:-1] + [x_shape[-1] // 2] + half_shape = [*x_shape[:-1], x_shape[-1] // 2] self.make_value(split_output_0, compute_dtype, half_shape) self.make_value(split_output_1, compute_dtype, half_shape) @@ -807,8 +808,6 @@ def make_model(self, input_path): self.make_preprocessing_nodes() # Load the Hugging Face model - from transformers import Qwen2_5_VLForConditionalGeneration - print("Loading Qwen2_5_VLForConditionalGeneration model...") hf_model = Qwen2_5_VLForConditionalGeneration.from_pretrained( self.model_name_or_path, diff --git a/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py b/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py index a1d9e98622..82f03133e0 100644 --- a/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py +++ b/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py @@ -186,23 +186,23 @@ def test_parity( config = hf_text_model.config # Get model parameters - BATCH_SIZE = 1 - PREFILL_LEN = 10 - DECODE_LEN = 1 - HIDDEN_SIZE = config.hidden_size - NUM_LAYERS = config.num_hidden_layers - NUM_KV_HEADS = config.num_key_value_heads - HEAD_DIM = config.hidden_size // config.num_attention_heads - VOCAB_SIZE = config.vocab_size # Get vocab size for output + batch_size = 1 + prefill_len = 10 + decode_len = 1 + hidden_size = config.hidden_size + num_layers = config.num_hidden_layers + num_kv_heads = config.num_key_value_heads + head_dim = config.hidden_size // config.num_attention_heads + vocab_size = config.vocab_size # Get vocab size for output print("\n--- Model Parameters ---") print(f"Device: {device}") print(f"DType: {torch_dtype}") print(f"RTOL: {rtol}, ATOL: {atol}") - print(f"Layers: {NUM_LAYERS}") - print(f"Hidden Size: {HIDDEN_SIZE}") - print(f"KV Heads: {NUM_KV_HEADS}") - print(f"Head Dim: {HEAD_DIM}") + print(f"Layers: {num_layers}") + print(f"Hidden Size: {hidden_size}") + print(f"KV Heads: {num_kv_heads}") + print(f"Head Dim: {head_dim}") print("------------------------\n") print(f"Loading ONNX model: {onnx_model_path}") @@ -212,23 +212,23 @@ def test_parity( # ================================================================= # 1. PREFILL STEP # ================================================================= - print(f"Running Prefill Step (Sequence Length = {PREFILL_LEN})...") + print(f"Running Prefill Step (Sequence Length = {prefill_len})...") # --- Create HF/Torch Inputs --- # Use randn (normal distribution) scaled down for better stability in FP16 # inputs_embeds are normally centered around 0, unlike rand which is [0, 1] inputs_embeds_prefill = ( - torch.randn((BATCH_SIZE, PREFILL_LEN, HIDDEN_SIZE), dtype=torch_dtype, device=device) * 0.001 + torch.randn((batch_size, prefill_len, hidden_size), dtype=torch_dtype, device=device) * 0.001 ) # Qwen2.5-VL uses 3D position IDs (temporal, height, width). # For text tokens, all three dimensions typically use the same sequence index. - pos_ids_1d_prefill = torch.arange(PREFILL_LEN, device=device).expand(BATCH_SIZE, -1) + pos_ids_1d_prefill = torch.arange(prefill_len, device=device).expand(batch_size, -1) position_ids_prefill = pos_ids_1d_prefill.unsqueeze(0).expand(3, -1, -1) - attention_mask_prefill = torch.ones((BATCH_SIZE, PREFILL_LEN), dtype=torch.int64, device=device) + attention_mask_prefill = torch.ones((batch_size, prefill_len), dtype=torch.int64, device=device) - cache_position_prefill = torch.arange(PREFILL_LEN, device=device) + cache_position_prefill = torch.arange(prefill_len, device=device) # --- Create ONNX Input Tensors (on device) --- ort_inputs_prefill = { @@ -238,19 +238,19 @@ def test_parity( } # Create dummy pasts with 0 sequence length - past_shape = (BATCH_SIZE, NUM_KV_HEADS, 0, HEAD_DIM) + past_shape = (batch_size, num_kv_heads, 0, head_dim) dummy_past = torch.empty(past_shape, dtype=torch_dtype, device=device) - for i in range(NUM_LAYERS): + for i in range(num_layers): ort_inputs_prefill[f"past_key_values.{i}.key"] = dummy_past ort_inputs_prefill[f"past_key_values.{i}.value"] = dummy_past # --- Create ONNX Output Tensors (on device) --- - ort_logits_prefill = torch.empty((BATCH_SIZE, PREFILL_LEN, VOCAB_SIZE), dtype=logits_dtype, device=device) + ort_logits_prefill = torch.empty((batch_size, prefill_len, vocab_size), dtype=logits_dtype, device=device) ort_presents_prefill = [] ort_outputs_prefill = {"logits": ort_logits_prefill} - present_shape = (BATCH_SIZE, NUM_KV_HEADS, PREFILL_LEN, HEAD_DIM) + present_shape = (batch_size, num_kv_heads, prefill_len, head_dim) - for i in range(NUM_LAYERS): + for i in range(num_layers): ort_present_k = torch.empty(present_shape, dtype=torch_dtype, device=device) ort_present_v = torch.empty(present_shape, dtype=torch_dtype, device=device) ort_outputs_prefill[f"present.{i}.key"] = ort_present_k @@ -289,19 +289,19 @@ def test_parity( # ================================================================= # 2. DECODE STEP # ================================================================= - print(f"Running Decode Step (Sequence Length = {DECODE_LEN})...") + print(f"Running Decode Step (Sequence Length = {decode_len})...") # --- Create HF/Torch Inputs --- # Use randn (normal distribution) scaled down - inputs_embeds_decode = torch.randn((BATCH_SIZE, DECODE_LEN, HIDDEN_SIZE), dtype=torch_dtype, device=device) * 0.001 + inputs_embeds_decode = torch.randn((batch_size, decode_len, hidden_size), dtype=torch_dtype, device=device) * 0.001 # Position IDs continue from prefill length - pos_ids_1d_decode = torch.tensor([[PREFILL_LEN]], dtype=torch.int64, device=device) + pos_ids_1d_decode = torch.tensor([[prefill_len]], dtype=torch.int64, device=device) position_ids_decode = pos_ids_1d_decode.unsqueeze(0).expand(3, -1, -1) - attention_mask_decode = torch.ones((BATCH_SIZE, PREFILL_LEN + DECODE_LEN), dtype=torch.int64, device=device) + attention_mask_decode = torch.ones((batch_size, prefill_len + decode_len), dtype=torch.int64, device=device) - cache_position_decode = torch.tensor([PREFILL_LEN], device=device) + cache_position_decode = torch.tensor([prefill_len], device=device) # Use the KV cache from the HF prefill run hf_past_key_values = hf_outputs_prefill.past_key_values @@ -314,23 +314,23 @@ def test_parity( } # Use the KV cache from the ONNX prefill run (these are already torch tensors) - for i in range(NUM_LAYERS): + for i in range(num_layers): ort_inputs_decode[f"past_key_values.{i}.key"] = ort_presents_prefill[i * 2] ort_inputs_decode[f"past_key_values.{i}.value"] = ort_presents_prefill[i * 2 + 1] # --- Create ONNX Output Tensors (on device) --- # --- FIX: Logits from bf16 ONNX model are intentionally float32 for accuracy --- - ort_logits_decode = torch.empty((BATCH_SIZE, DECODE_LEN, VOCAB_SIZE), dtype=logits_dtype, device=device) + ort_logits_decode = torch.empty((batch_size, decode_len, vocab_size), dtype=logits_dtype, device=device) ort_presents_decode = [] ort_outputs_decode = {"logits": ort_logits_decode} present_shape_decode = ( - BATCH_SIZE, - NUM_KV_HEADS, - PREFILL_LEN + DECODE_LEN, - HEAD_DIM, + batch_size, + num_kv_heads, + prefill_len + decode_len, + head_dim, ) - for i in range(NUM_LAYERS): + for i in range(num_layers): ort_present_k = torch.empty(present_shape_decode, dtype=torch_dtype, device=device) ort_present_v = torch.empty(present_shape_decode, dtype=torch_dtype, device=device) ort_outputs_decode[f"present.{i}.key"] = ort_present_k From 7ccc607767bddc2010cc56c9f984635317f3246d Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 19 Nov 2025 23:00:12 +0000 Subject: [PATCH 08/18] refine --- test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py b/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py index 82f03133e0..ca36566de8 100644 --- a/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py +++ b/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py @@ -94,8 +94,7 @@ def ort_io_binding_helper( for name, tensor in input_tensors.items(): if not tensor.is_contiguous(): - print(f"Warning: Input tensor {name} is not contiguous. Making it contiguous.") - input_tensors[name] = tensor.contiguous() + raise RuntimeError(f"Input tensor {name} is not contiguous.") bind.bind_input( name, @@ -108,8 +107,7 @@ def ort_io_binding_helper( for name, tensor in output_tensors.items(): if not tensor.is_contiguous(): - print(f"Warning: Output tensor {name} is not contiguous. Making it contiguous.") - output_tensors[name] = tensor.contiguous() + raise RuntimeError(f"Output tensor {name} is not contiguous.") bind.bind_output( name, @@ -193,7 +191,7 @@ def test_parity( num_layers = config.num_hidden_layers num_kv_heads = config.num_key_value_heads head_dim = config.hidden_size // config.num_attention_heads - vocab_size = config.vocab_size # Get vocab size for output + vocab_size = config.vocab_size print("\n--- Model Parameters ---") print(f"Device: {device}") @@ -224,7 +222,7 @@ def test_parity( # Qwen2.5-VL uses 3D position IDs (temporal, height, width). # For text tokens, all three dimensions typically use the same sequence index. pos_ids_1d_prefill = torch.arange(prefill_len, device=device).expand(batch_size, -1) - position_ids_prefill = pos_ids_1d_prefill.unsqueeze(0).expand(3, -1, -1) + position_ids_prefill = pos_ids_1d_prefill.unsqueeze(0).expand(3, -1, -1).contiguous() attention_mask_prefill = torch.ones((batch_size, prefill_len), dtype=torch.int64, device=device) @@ -297,7 +295,7 @@ def test_parity( # Position IDs continue from prefill length pos_ids_1d_decode = torch.tensor([[prefill_len]], dtype=torch.int64, device=device) - position_ids_decode = pos_ids_1d_decode.unsqueeze(0).expand(3, -1, -1) + position_ids_decode = pos_ids_1d_decode.unsqueeze(0).expand(3, -1, -1).contiguous() attention_mask_decode = torch.ones((batch_size, prefill_len + decode_len), dtype=torch.int64, device=device) From 20891af7d75169d56e9894af03b0737c64e923c0 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 21 Nov 2025 00:29:00 +0000 Subject: [PATCH 09/18] review feedback --- src/python/py/models/builder.py | 4 + src/python/py/models/builders/base.py | 7 +- src/python/py/models/builders/qwen.py | 254 +++++++++++++++----------- 3 files changed, 158 insertions(+), 107 deletions(-) diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index b4b203618e..f440cf717b 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -304,6 +304,10 @@ def create_model( elif config.architectures[0] == "SmolLM3ForCausalLM": onnx_model = SmolLM3Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options) elif config.architectures[0] == "Qwen2_5_VLForConditionalGeneration": + text_config = config.text_config + for key in text_config: + if not hasattr(config, key): + setattr(config, key, getattr(text_config, key)) print( "WARNING: This is only generating the text component of the model. Setting `--extra_options exclude_embeds=true` by default." ) diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index b14c0755e9..1d98877e8b 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -466,7 +466,7 @@ def make_rope_init(self, config): "sections": config.rope_scaling["mrope_section"], # Sections for MRoPE } - def make_attention_init(self): + def is_gqa_supported(self) -> bool: valid_gqa_configurations = { ("cpu", ir.DataType.FLOAT), ("cuda", ir.DataType.FLOAT16), @@ -476,7 +476,10 @@ def make_attention_init(self): ("webgpu", ir.DataType.FLOAT), ("trt-rtx", ir.DataType.FLOAT16), } - if (self.ep, self.io_dtype) in valid_gqa_configurations: + return (self.ep, self.io_dtype) in valid_gqa_configurations + + def make_attention_init(self): + if self.is_gqa_supported(): # Change model settings for GroupQueryAttention self.attention_attrs["op_type"] = "GroupQueryAttention" print("GroupQueryAttention (GQA) is used in this model.") diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index 107b851df7..e9a7a6fa86 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -4,7 +4,6 @@ # license information. # -------------------------------------------------------------------------- -import os import onnx_ir as ir import torch @@ -30,23 +29,6 @@ def make_attention_init(self): class Qwen25VLTextModel(Model): def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): - # We must extract the text_config for the text model's parameters - text_config_dict = config.text_config.to_dict() - - # Update the main config with text-specific parameters - # The base.Model class reads from the top-level config object - config.hidden_size = text_config_dict["hidden_size"] - config.intermediate_size = text_config_dict["intermediate_size"] - config.num_attention_heads = text_config_dict["num_attention_heads"] - config.num_hidden_layers = text_config_dict["num_hidden_layers"] - config.num_key_value_heads = text_config_dict["num_key_value_heads"] - config.rms_norm_eps = text_config_dict["rms_norm_eps"] - config.sliding_window = text_config_dict["sliding_window"] - config.rope_scaling = text_config_dict["rope_scaling"] - # Need this for attention_scaling calculation - if "original_max_position_embeddings" in text_config_dict: - config.original_max_position_embeddings = text_config_dict["original_max_position_embeddings"] - super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) # The HF model (Qwen2RMSNorm) *always* computes LayerNorm in float32. @@ -65,7 +47,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): self.layernorm_attrs["cast"]["skip_input"] = True self.layernorm_attrs["cast"]["output_0"] = True self.layernorm_attrs["cast"]["output_3"] = True - # + # Qwen2's RoPE *always* computes in float32. # We must replicate this behavior. print("Forcing RoPE computation to float32 for Qwen2.5-VL parity.") @@ -73,48 +55,27 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): self.attention_attrs["rope_cast"] = {} self.attention_attrs["rope_cast"]["use_fp32"] = True - # The base.Model.make_outputs_init() *always* casts logits to float32 - # if the io_dtype is bfloat16. This is to improve accuracy in general. - # - # PROBLEM: The HF model (Qwen2_5_VL) *does not* do this. It computes - # the lm_head MatMul in bfloat16 and returns bfloat16 logits. - # This causes the parity test (which compares bf16 vs fp32) to fail. - # - # SOLUTION: We must override the base model's decision and set the - # output logits type to match the io_dtype (bfloat16). - # - self.allow_bf16_logits = os.getenv("ALLOW_BF16_LOGITS") in ["1", "true", "True"] - if self.allow_bf16_logits and self.io_dtype == ir.DataType.BFLOAT16: - print("Fixing output logits precision. Setting output_types['logits'] to BFLOAT16 to match HF model.") - self.output_types["logits"] = ir.DataType.BFLOAT16 - - # Manually get the attention_scaling from the rope_config - # This replicates the logic from transformers.models.rope_utils._config_to_init_values + # Manually get the rope_attention_scaling from the rope_config + # Support rope types: 'default' or 'yarn' according to model cards in huggingface. Examples: + # "rope_scaling": {"type": "mrope", "mrope_section": [16, 24,24]} + # "rope_scaling": {"type": "yarn", "mrope_section": [ 16, 24, 24 ], "factor": 4, "original_max_position_embeddings": 32768 }} rope_type = "default" if config.rope_scaling and "type" in config.rope_scaling: # The config re-maps 'mrope' to 'default' if config.rope_scaling["type"] != "mrope": rope_type = config.rope_scaling["type"] + assert rope_type in ["default", "yarn"], f"Unsupported rope_type for this model: {rope_type}" + self.rope_attention_scaling = 1.0 if rope_type == "yarn": factor = config.rope_scaling.get("factor", 1.0) - self.rope_attrs["attention_scaling"] = config.rope_scaling.get( + self.rope_attention_scaling = config.rope_scaling.get( "attention_factor", (0.1 * torch.log(torch.tensor(factor)) + 1.0).item() ) - elif rope_type == "longrope": - factor = config.rope_scaling.get("factor", 1.0) - orig_max_pos = config.original_max_position_embeddings - self.rope_attrs["attention_scaling"] = config.rope_scaling.get( - "attention_factor", - torch.sqrt(1 + torch.log(torch.tensor(factor)) / torch.log(torch.tensor(orig_max_pos))).item(), - ) - else: - self.rope_attrs["attention_scaling"] = 1.0 # Qwen 2.5 VL applies RoPE manually before attention, not fused in the op self.attention_attrs["use_rope_in_attn"] = False - # Your inheritance change fixed this, but this check is harmless and safe. if "position_ids" not in self.input_names: print("Re-adding 'position_ids' to self.input_names.") if "attention_mask" in self.input_names: @@ -137,16 +98,11 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): f"MRoPE splits {self.mrope_splits} sum ({sum(self.mrope_splits)}) does not match head size ({self.head_size})" ) - # Force GroupQueryAttention for fp32 cuda, - # as base.py's make_attention_init doesn't include this combo. - if self.ep == "cuda" and self.io_dtype == ir.DataType.FLOAT: - self.attention_attrs["op_type"] = "GroupQueryAttention" - print("Forcing GroupQueryAttention (GQA) for FP32 CUDA.") + # Force GroupQueryAttention since make_attention() below only implements GQA. + self.attention_attrs["op_type"] = "GroupQueryAttention" - if self.attention_attrs["op_type"] != "GroupQueryAttention": - raise ValueError( - f"Qwen2.5-VL requires GroupQueryAttention, but op_type is {self.attention_attrs['op_type']}. This may be due to an unsupported EP/precision combo." - ) + if not self.is_gqa_supported(): + print(f"Warning: {self.ep} does not support GQA for {self.io_dtype}, so GQA might fallback to CPU!") # Create and save the inv_freq tensor self.make_inv_freq_tensor() @@ -180,21 +136,47 @@ def make_inputs_and_outputs(self): super().make_inputs_and_outputs() def make_dynamic_rope_caches(self, layer_id, basename): - """ - Re-implements Qwen2_5_VLRotaryEmbedding.forward using ONNX ops. - Takes 3D position_ids and inv_freq and dynamically creates - the cos/sin caches. - """ + # Make nodes for the Dynamic RoPE Cache subgraph + # + # Re-implements Qwen2_5_VLRotaryEmbedding.forward using ONNX ops. + # Takes 3D position_ids and inv_freq and dynamically creates + # the cos/sin caches. + # + # inv_freq (H/2) position_ids (3, B, S) + # | | + # Unsqueeze Unsqueeze + # | | + # Expand Cast + # (3, B, H/2, 1) (3, B, 1, S) + # | | + # +--------------------------+---------------------------+ + # | + # MatMul + # (3, B, H/2, S) + # | + # Transpose + # (3, B, S, H/2) + # | + # Concat + # (3, B, S, H) + # | + # +-------------+-------------+ + # | | + # Cos Sin + # | | + # Mul Mul + # (apply scaling) (apply scaling) + # pos_ids_name = "position_ids" inv_freq_name = "model.inv_freq" head_dim_half = self.head_size // 2 # Get Batch Size from position_ids.shape[1] - shape_pos_ids_name = f"{basename}/shape_pos_ids" + shape_pos_ids_name = f"{basename}/pos_ids/Shape" shape_pos_ids_output = f"{shape_pos_ids_name}/output_0" self.make_shape(shape_pos_ids_name, pos_ids_name, [3]) - gather_batch_size_name = f"{basename}/gather_batch_size" + gather_batch_size_name = f"{basename}/pos_ids/Gather" gather_batch_size_output = f"{gather_batch_size_name}/output_0" self.make_gather( gather_batch_size_name, @@ -205,7 +187,7 @@ def make_dynamic_rope_caches(self, layer_id, basename): ) # Expand inv_freq: [H/2] -> [1, 1, H/2, 1] - unsqueeze_1_name = f"{basename}/inv_freq_unsqueeze_1" + unsqueeze_1_name = f"{basename}/inv_freq/Unsqueeze" unsqueeze_1_output = f"{unsqueeze_1_name}/output_0" self.make_unsqueeze( unsqueeze_1_name, @@ -215,7 +197,7 @@ def make_dynamic_rope_caches(self, layer_id, basename): ) # Create target shape for Expand: [3, B, H/2, 1] - concat_expand_shape_name = f"{basename}/concat_expand_shape" + concat_expand_shape_name = f"{basename}/expand_shape/Concat" concat_expand_shape_output = f"{concat_expand_shape_name}/output_0" self.make_concat( concat_expand_shape_name, @@ -229,7 +211,7 @@ def make_dynamic_rope_caches(self, layer_id, basename): axis=0, ) - expand_name = f"{basename}/inv_freq_expand" + expand_name = f"{basename}/inv_freq/Expand" expand_output = f"{expand_name}/output_0" self.make_expand( expand_name, @@ -239,7 +221,7 @@ def make_dynamic_rope_caches(self, layer_id, basename): ) # Expand position_ids: [3, B, S] -> [3, B, 1, S] - unsqueeze_2_name = f"{basename}/pos_ids_unsqueeze" + unsqueeze_2_name = f"{basename}/pos_ids/Unsqueeze" unsqueeze_2_output = f"{unsqueeze_2_name}/output_0" self.make_unsqueeze( unsqueeze_2_name, @@ -249,7 +231,7 @@ def make_dynamic_rope_caches(self, layer_id, basename): ) # Cast position_ids to float - cast_name = f"{basename}/pos_ids_cast" + cast_name = f"{basename}/pos_ids/Cast" cast_output = f"{cast_name}/output_0" self.make_cast( cast_name, @@ -259,7 +241,7 @@ def make_dynamic_rope_caches(self, layer_id, basename): ) # MatMul: [3, B, H/2, 1] @ [3, B, 1, S] -> [3, B, H/2, S] - matmul_name = f"{basename}/freqs_matmul" + matmul_name = f"{basename}/freqs/MatMul" matmul_output = f"{matmul_name}/output_0" self.make_node("MatMul", [expand_output, cast_output], [matmul_output], name=matmul_name) self.make_value( @@ -269,7 +251,7 @@ def make_dynamic_rope_caches(self, layer_id, basename): ) # Transpose: [3, B, H/2, S] -> [3, B, S, H/2] - transpose_name = f"{basename}/freqs_transpose" + transpose_name = f"{basename}/freqs/Transpose" transpose_output = f"{transpose_name}/output_0" self.make_transpose( transpose_name, @@ -280,7 +262,7 @@ def make_dynamic_rope_caches(self, layer_id, basename): ) # Concat (freqs, freqs): [3, B, S, H/2] -> [3, B, S, H] - concat_name = f"{basename}/emb_concat" + concat_name = f"{basename}/Concat" concat_output = f"{concat_name}/output_0" self.make_concat( concat_name, @@ -291,7 +273,7 @@ def make_dynamic_rope_caches(self, layer_id, basename): ) # Cos(emb) and Sin(emb) - cos_name = f"{basename}/cos" + cos_name = f"{basename}/Cos" cos_output = f"{cos_name}/output_0" self.make_node("Cos", [concat_output], [cos_output], name=cos_name) self.make_value( @@ -300,7 +282,7 @@ def make_dynamic_rope_caches(self, layer_id, basename): [3, "batch_size", "sequence_length", self.head_size], ) - sin_name = f"{basename}/sin" + sin_name = f"{basename}/Sin" sin_output = f"{sin_name}/output_0" self.make_node("Sin", [concat_output], [sin_output], name=sin_name) self.make_value( @@ -309,15 +291,15 @@ def make_dynamic_rope_caches(self, layer_id, basename): [3, "batch_size", "sequence_length", self.head_size], ) - # Apply attention_scaling + # Apply scaling when rope type is "yarn". cos_final_output = cos_output sin_final_output = sin_output - scale = self.rope_attrs.get("attention_scaling", 1.0) # Get from rope_attrs + scale = self.rope_attention_scaling if scale != 1.0: scale_const_name = f"/model/constants/FLOAT/{scale}" - cos_mul_name = f"{basename}/cos_mul_scale" + cos_mul_name = f"{basename}/cos_scale/Mul" cos_final_output = f"{cos_mul_name}/output_0" self.make_node( "Mul", @@ -331,7 +313,7 @@ def make_dynamic_rope_caches(self, layer_id, basename): [3, "batch_size", "sequence_length", self.head_size], ) - sin_mul_name = f"{basename}/sin_mul_scale" + sin_mul_name = f"{basename}/sin_scale/Mul" sin_final_output = f"{sin_mul_name}/output_0" self.make_node( "Mul", @@ -382,11 +364,40 @@ def rotate_half(self, x_name, x_shape, basename, compute_dtype): return concat_output def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn_sin, num_heads, basename): - """ - Re-implements apply_multimodal_rotary_pos_emb using ONNX ops. - Takes Q/K tensor and the dynamically generated 3D caches - and applies the rotation. - """ + # Make nodes for the MRoPE rotation subgraph + # + # Re-implements apply_multimodal_rotary_pos_emb using ONNX ops. + # Takes Q/K tensor and the dynamically generated 3D caches + # and applies the rotation. + # + # dyn_cos (3, B, S, H) dyn_sin (3, B, S, H) q_or_k (B, S, N*H) + # | | | + # Split Split Reshape + # (into 6 parts) (into 6 parts) | + # | | Transpose + # +-------+-------+ +-------+-------+ (B, N, S, H) + # | ...loop... | | ...loop... | | + # Gather(dim_idx) | Gather(dim_idx) | | + # | | | | | + # Unsqueeze | Unsqueeze | | + # | | | | | + # +-------+-------+ +-------+-------+ | + # | | | + # Concat Concat | + # (B, 1, S, H) (B, 1, S, H) | + # | | | + # +-----------+----------+ | + # | | + # (Mixed Precision Casts) | + # | | + # +-----------------------+-----------+ + # | + # (q * cos) + (rotate_half(q) * sin) + # | + # Transpose + # | + # Reshape + # # --- Handle precision for RoPE --- # Check if we need to force float32 computation @@ -399,7 +410,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn # Create a Constant node for mrope_splits # This holds the correct splits, e.g., [16, 24, 24, 16, 24, 24] - mrope_splits_node_name = f"{basename}/mrope_splits_node" + mrope_splits_node_name = f"{basename}/mrope_splits/Constant" mrope_splits_output_name = f"{basename}/mrope_splits" mrope_splits_tensor = ir.tensor( torch.tensor(self.mrope_splits, dtype=torch.int64), @@ -418,7 +429,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn # Caches (dyn_cos, dyn_sin) are already in float32 num_splits = len(self.mrope_splits) - cos_split_name = f"{basename}/cos_split" + cos_split_name = f"{basename}/cos/Split" cos_split_outputs = [f"{cos_split_name}/output_{i}" for i in range(num_splits)] self.make_node( "Split", @@ -428,7 +439,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn axis=-1, ) - sin_split_name = f"{basename}/sin_split" + sin_split_name = f"{basename}/sin/Split" sin_split_outputs = [f"{sin_split_name}/output_{i}" for i in range(num_splits)] self.make_node( "Split", @@ -447,7 +458,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn # Gather from dim 0 of the split cache chunk # input is [3, B, S, H_chunk], indices is [0, 1, or 2] - gather_cos_name = f"{basename}/cos_gather_{i}" + gather_cos_name = f"{basename}/cos_{i}/Gather" gather_cos_output = f"{gather_cos_name}/output_0" self.make_node( "Gather", @@ -462,7 +473,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn [1, "batch_size", "sequence_length", dim_chunk], ) # Shape [1, B, S, H_chunk] - gather_sin_name = f"{basename}/sin_gather_{i}" + gather_sin_name = f"{basename}/sin_{i}/Gather" gather_sin_output = f"{gather_sin_name}/output_0" self.make_node( "Gather", @@ -478,7 +489,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn ) # Shape [1, B, S, H_chunk] # FIX: Squeeze the gathered cache to [B, S, H_chunk] - squeeze_cos_name = f"{basename}/cos_squeeze_{i}" + squeeze_cos_name = f"{basename}/cos_{i}/Squeeze" squeeze_cos_output = f"{squeeze_cos_name}/output_0" self.make_squeeze( squeeze_cos_name, @@ -487,7 +498,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn ["batch_size", "sequence_length", dim_chunk], ) - squeeze_sin_name = f"{basename}/sin_squeeze_{i}" + squeeze_sin_name = f"{basename}/sin_{i}/Squeeze" squeeze_sin_output = f"{squeeze_sin_name}/output_0" self.make_squeeze( squeeze_sin_name, @@ -497,7 +508,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn ) # Unsqueeze to add the NumHeads dim: [B, 1, S, H_chunk] - unsqueeze_cos_name = f"{basename}/cos_unsqueeze_{i}" + unsqueeze_cos_name = f"{basename}/cos_{i}/Unsqueeze" unsqueeze_cos_output = f"{unsqueeze_cos_name}/output_0" self.make_unsqueeze( unsqueeze_cos_name, @@ -507,7 +518,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn ) cos_reordered.append(unsqueeze_cos_output) - unsqueeze_sin_name = f"{basename}/sin_unsqueeze_{i}" + unsqueeze_sin_name = f"{basename}/sin_{i}/Unsqueeze" unsqueeze_sin_output = f"{unsqueeze_sin_name}/output_0" self.make_unsqueeze( unsqueeze_sin_name, @@ -518,7 +529,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn sin_reordered.append(unsqueeze_sin_output) # Concat re-ordered chunks back to [B, 1, S, H] - final_cos_concat_name = f"{basename}/cos_final_concat" + final_cos_concat_name = f"{basename}/cos_final/Concat" final_cos_concat_output = f"{final_cos_concat_name}/output_0" self.make_concat( final_cos_concat_name, @@ -528,7 +539,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn axis=-1, ) - final_sin_concat_name = f"{basename}/sin_final_concat" + final_sin_concat_name = f"{basename}/sin_final/Concat" final_sin_concat_output = f"{final_sin_concat_name}/output_0" self.make_concat( final_sin_concat_name, @@ -540,8 +551,8 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn # Caches (final_cos_concat_output, final_sin_concat_output) are now in float32 - # Reshape input Q/K: [B, S, N*H] -> [B, N, S, H] - reshape_1_name = f"{basename}/q_or_k_reshape_1" + # Reshape input Q/K: [B, S, N*H] -> [B, S, N, H] + reshape_1_name = f"{basename}/q_or_k_bsd_to_bsnh/Reshape" reshape_1_output = f"{reshape_1_name}/output_0" reshape_1_target_shape_onnx = f"/model/constants/INT64/[0, 0, {num_heads}, {self.head_size}]" reshape_1_target_shape_ort = [ @@ -558,7 +569,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn ) # Transpose Q/K: [B, S, N, H] -> [B, N, S, H] - transpose_1_name = f"{basename}/q_or_k_transpose_1" + transpose_1_name = f"{basename}/q_or_k_bsnh_to_bnsh/Transpose" transpose_1_output = f"{transpose_1_name}/output_0" transpose_1_target_shape = [ "batch_size", @@ -581,7 +592,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn if force_fp32 and self.io_dtype != ir.DataType.FLOAT: # Cast Q/K (self.io_dtype) up to float32 - q_or_k_cast_name = f"{basename}/q_or_k_cast_fp32" + q_or_k_cast_name = f"{basename}/q_or_k/Cast" q_or_k_cast_output = f"{q_or_k_cast_name}/output_0" self.make_cast( q_or_k_cast_name, @@ -592,7 +603,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn q_or_k_compute_input = q_or_k_cast_output elif not force_fp32 and self.io_dtype != ir.DataType.FLOAT: # Cast Caches (float32) down to self.io_dtype - cos_cache_cast_name = f"{basename}/cos_final_cast" + cos_cache_cast_name = f"{basename}/cos_final/Cast" cos_cache_cast_output = f"{cos_cache_cast_name}/output_0" self.make_cast( cos_cache_cast_name, @@ -602,7 +613,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn ) cos_cache_compute_input = cos_cache_cast_output - sin_cache_cast_name = f"{basename}/sin_final_cast" + sin_cache_cast_name = f"{basename}/sin_final/Cast" sin_cache_cast_output = f"{sin_cache_cast_name}/output_0" self.make_cast( sin_cache_cast_name, @@ -615,7 +626,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn # Apply rotation: (q * cos) + (rotate_half(q) * sin) # 1. (q * cos) - mul_1_name = f"{basename}/mul_1" + mul_1_name = f"{basename}/Mul_1" mul_1_output = f"{mul_1_name}/output_0" self.make_mul( mul_1_name, @@ -628,7 +639,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn rotated_half_q_name = self.rotate_half(q_or_k_compute_input, transpose_1_target_shape, basename, compute_dtype) # 3. (rotate_half(q) * sin) - mul_2_name = f"{basename}/mul_2" + mul_2_name = f"{basename}/Mul_2" mul_2_output = f"{mul_2_name}/output_0" self.make_mul( mul_2_name, @@ -638,7 +649,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn ) # 4. (q * cos) + (rotate_half(q) * sin) - add_name = f"{basename}/add" + add_name = f"{basename}/add/Add" add_output = f"{add_name}/output_0" self.make_add( add_name, @@ -652,13 +663,13 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn add_output_final = add_output if force_fp32 and self.io_dtype != ir.DataType.FLOAT: # Cast result back down to self.io_dtype - add_cast_name = f"{basename}/add_cast_output" + add_cast_name = f"{basename}/add/Cast" add_cast_output = f"{add_cast_name}/output_0" self.make_cast(add_cast_name, add_output, output_dtype, transpose_1_target_shape) add_output_final = add_cast_output # Transpose back: [B, N, S, H] -> [B, S, N, H] - transpose_2_name = f"{basename}/q_or_k_transpose_2" + transpose_2_name = f"{basename}/q_or_k_bnsh_to_bsnh/Transpose" transpose_2_output = f"{transpose_2_name}/output_0" self.make_transpose( transpose_2_name, @@ -669,7 +680,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn ) # Reshape back: [B, S, N, H] -> [B, S, N*H] - reshape_2_name = f"{basename}/q_or_k_reshape_2" + reshape_2_name = f"{basename}/q_or_k_bsnh_to_bsd/Reshape" reshape_2_output = f"{reshape_2_name}/output_0" self.make_reshape( reshape_2_name, @@ -684,6 +695,39 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn return reshape_2_output def make_attention(self, layer_id, attention, root_input, **kwargs): + # Make nodes for the Attention subgraph (with MRoPE) + # + # root_input + # / | \ + # / | \ + # Q_MatMul K_MatMul V_MatMul + # | | | + # Q_Add K_Add V_Add + # | | | + # | | +-----------------+ + # | | | + # (make_dynamic_rope_caches) | + # | | + # +-----+-----+ | + # | | | + # dyn_cos dyn_sin | + # | | | + # v v | + # (apply_mrope_rotation for Q) | + # | | + # Q_Rot | + # | (apply_mrope_rotation for K) | + # | | | + # | K_Rot | + # | | | + # +--------+--------+ | + # | | + # GroupQueryAttention <--------------+ + # | + # O_MatMul + # | + # O_Add + # 1. Unpack QKV if necessary (e.g. qkv_proj) super().make_attention_unpacked(layer_id, attention, root_input, **kwargs) From e5359b3881a1aec4595dff77741333fcd667e36f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 21 Nov 2025 00:39:04 +0000 Subject: [PATCH 10/18] undo phi.py --- src/python/py/models/builders/phi.py | 27 ++++----------------------- 1 file changed, 4 insertions(+), 23 deletions(-) diff --git a/src/python/py/models/builders/phi.py b/src/python/py/models/builders/phi.py index c29208ab4b..9aa309441d 100644 --- a/src/python/py/models/builders/phi.py +++ b/src/python/py/models/builders/phi.py @@ -121,26 +121,13 @@ def make_position_ids_reformatting(self): ] self.make_greater_or_equal(greater_or_equal_name, greater_or_equal_inputs, shape=[]) cast_name = f"{basename}/Cast" - self.make_cast( - cast_name, - f"{greater_or_equal_name}/output_0", - dtype=compute_dtype, - shape=None, - ) + self.make_cast(cast_name, f"{greater_or_equal_name}/output_0", dtype=compute_dtype, shape=None) mul_name = f"{basename}/Mul" - mul_inputs = [ - f"{cast_name}/output_0", - f"/model/constants/{compute_str_dtype}/{self.original_context_length}", - ] + mul_inputs = [f"{cast_name}/output_0", f"/model/constants/{compute_str_dtype}/{self.original_context_length}"] self.make_mul(mul_name, mul_inputs, dtype=compute_dtype, shape=None) add_1_name = f"{basename}/Add_1" add_1_inputs = [f"{mul_name}/output_0", input_tensor] - self.make_add( - add_1_name, - add_1_inputs, - dtype=compute_dtype, - shape=["batch_size", "sequence_length"], - ) + self.make_add(add_1_name, add_1_inputs, dtype=compute_dtype, shape=["batch_size", "sequence_length"]) # Cast back to int64 for WebGPU to maintain compatibility result_name = add_1_name @@ -158,13 +145,7 @@ def make_position_ids_reformatting(self): def make_attention(self, layer_id, attention, root_input, **kwargs): if self.position_ids_name is not None: - super().make_attention( - layer_id, - attention, - root_input, - position_ids=self.position_ids_name, - **kwargs, - ) + super().make_attention(layer_id, attention, root_input, position_ids=self.position_ids_name, **kwargs) else: super().make_attention(layer_id, attention, root_input, **kwargs) From 9785d0a31f2182499429b76303fb259f4a5104d7 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 21 Nov 2025 00:41:17 +0000 Subject: [PATCH 11/18] update test script --- test/python/models/qwen_2.5_vl/run.sh | 11 ++++++++--- test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py | 14 ++++---------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/test/python/models/qwen_2.5_vl/run.sh b/test/python/models/qwen_2.5_vl/run.sh index 9eee36e151..6da708b743 100644 --- a/test/python/models/qwen_2.5_vl/run.sh +++ b/test/python/models/qwen_2.5_vl/run.sh @@ -19,9 +19,10 @@ fi # 2. Define variables based on input PRECISION=$1 -OUTPUT_DIR="./qwen_${PRECISION}" +TEST_DIR="$(CDPATH= cd -- "$(dirname -- "$0")" && pwd)" +OUTPUT_DIR="${TEST_DIR}/qwen_${PRECISION}" ONNX_MODEL_PATH="${OUTPUT_DIR}/model.onnx" -CACHE_DIR="./cache" +CACHE_DIR="${TEST_DIR}/cache" HF_MODEL="Qwen/Qwen2.5-VL-3B-Instruct" # Set the --bf16 or --fp16 flag for the test script @@ -38,10 +39,13 @@ if [ "$2" == "-f" ] && [ -d "${OUTPUT_DIR}" ]; then rm -rf "${OUTPUT_DIR}" fi +BUILDER_DIR="$(cd ../../../../src/python/py/models && pwd)" + # 4. Run the builder script if output directory does not exist. if ! [ -d "${OUTPUT_DIR}" ]; then echo "--- Building ${PRECISION} model ---" - python -m onnxruntime_genai.models.builder \ + cd "${BUILDER_DIR}" + python builder.py \ -m ${HF_MODEL} \ -p ${PRECISION} \ -o ${OUTPUT_DIR} \ @@ -50,6 +54,7 @@ if ! [ -d "${OUTPUT_DIR}" ]; then fi # 5. Run the parity test +cd "${TEST_DIR}" echo "--- Testing ${PRECISION} model parity ---" python test_qwen_2.5_vl.py \ --hf_model ${HF_MODEL} \ diff --git a/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py b/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py index ca36566de8..3613434766 100644 --- a/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py +++ b/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py @@ -4,7 +4,6 @@ # license information. # -------------------------------------------------------------------------- import argparse -import os import numpy as np import onnxruntime as ort @@ -156,15 +155,10 @@ def test_parity( # Standard FP32 tolerances rtol, atol = 1e-1, 1e-1 - allow_bf16_logits = os.getenv("ALLOW_BF16_LOGITS") in ["1", "true", "True"] - - if allow_bf16_logits: - logits_dtype = torch_dtype - else: - # The builder script (base.Model) upcasts logits to float32 - # ONLY when the io_dtype is bfloat16. - # For FP16 or FP32, it keeps the original dtype. - logits_dtype = torch.float32 if use_bf16 else torch_dtype + # The builder script (base.Model) upcasts logits to float32 + # ONLY when the io_dtype is bfloat16. + # For FP16 or FP32, it keeps the original dtype. + logits_dtype = torch.float32 if use_bf16 else torch_dtype print(f"Allocating ONNX logits output buffer with dtype: {logits_dtype}") From d438559dc9fc057c8e52e5e9be945118108d55b3 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 21 Nov 2025 18:33:04 +0000 Subject: [PATCH 12/18] update comment --- src/python/py/models/builders/qwen.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index e9a7a6fa86..46de649a3b 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -33,9 +33,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): # The HF model (Qwen2RMSNorm) *always* computes LayerNorm in float32. # By inheriting from `base.Model`, all `layernorm_attrs["cast"]` flags - # are `False`. This causes two problems: - # 1. Parity Error (FP32 model): The 47% mismatch you saw. - # 2. Type Mismatch Error (BF16 model): The `(float)` vs `(bfloat16)` error. + # are `False`. This causes parity loss and type mismatch error. # # SOLUTION: Manually set all `cast` flags to `True`. This forces the # builder to cast bf16 inputs -> fp32, compute LN, and cast fp32 @@ -330,10 +328,24 @@ def make_dynamic_rope_caches(self, layer_id, basename): return cos_final_output, sin_final_output def rotate_half(self, x_name, x_shape, basename, compute_dtype): - """ - Builds ONNX nodes for rotate_half(x) - x_shape is [B, N, S, H] - """ + # Make nodes for rotate_half subgraph + # + # x (B, N, S, H) + # | + # Split + # / \ + # / \ + # x1 (..., H/2) x2 (..., H/2) + # | | + # | Neg + # | | + # | -x2 + # \ / + # \ / + # Concat + # | + # output (..., H) + # Split: [B, N, S, H] -> [B, N, S, H/2], [B, N, S, H/2] split_name = f"{basename}/rotate_half/Split" split_output_0 = f"{split_name}/output_0" From f39a188d411b6a6957a973bf016318978ac6d2ba Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 24 Nov 2025 19:29:09 +0000 Subject: [PATCH 13/18] remove yarn support --- src/python/py/models/builders/qwen.py | 57 ++------------------------- 1 file changed, 4 insertions(+), 53 deletions(-) diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index 46de649a3b..5f841d0496 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -53,23 +53,10 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): self.attention_attrs["rope_cast"] = {} self.attention_attrs["rope_cast"]["use_fp32"] = True - # Manually get the rope_attention_scaling from the rope_config - # Support rope types: 'default' or 'yarn' according to model cards in huggingface. Examples: - # "rope_scaling": {"type": "mrope", "mrope_section": [16, 24,24]} - # "rope_scaling": {"type": "yarn", "mrope_section": [ 16, 24, 24 ], "factor": 4, "original_max_position_embeddings": 32768 }} - rope_type = "default" + # Check rope type since huggingface model supports yarn but that is not recommended as mentioned in model card. Example: + # "rope_scaling": {"type": "mrope", "mrope_section": [16, 24,24]} if config.rope_scaling and "type" in config.rope_scaling: - # The config re-maps 'mrope' to 'default' - if config.rope_scaling["type"] != "mrope": - rope_type = config.rope_scaling["type"] - assert rope_type in ["default", "yarn"], f"Unsupported rope_type for this model: {rope_type}" - - self.rope_attention_scaling = 1.0 - if rope_type == "yarn": - factor = config.rope_scaling.get("factor", 1.0) - self.rope_attention_scaling = config.rope_scaling.get( - "attention_factor", (0.1 * torch.log(torch.tensor(factor)) + 1.0).item() - ) + assert config.rope_scaling["type"] == "mrope" # Qwen 2.5 VL applies RoPE manually before attention, not fused in the op self.attention_attrs["use_rope_in_attn"] = False @@ -289,43 +276,7 @@ def make_dynamic_rope_caches(self, layer_id, basename): [3, "batch_size", "sequence_length", self.head_size], ) - # Apply scaling when rope type is "yarn". - cos_final_output = cos_output - sin_final_output = sin_output - scale = self.rope_attention_scaling - - if scale != 1.0: - scale_const_name = f"/model/constants/FLOAT/{scale}" - - cos_mul_name = f"{basename}/cos_scale/Mul" - cos_final_output = f"{cos_mul_name}/output_0" - self.make_node( - "Mul", - [cos_output, scale_const_name], - [cos_final_output], - name=cos_mul_name, - ) - self.make_value( - cos_final_output, - ir.DataType.FLOAT, - [3, "batch_size", "sequence_length", self.head_size], - ) - - sin_mul_name = f"{basename}/sin_scale/Mul" - sin_final_output = f"{sin_mul_name}/output_0" - self.make_node( - "Mul", - [sin_output, scale_const_name], - [sin_final_output], - name=sin_mul_name, - ) - self.make_value( - sin_final_output, - ir.DataType.FLOAT, - [3, "batch_size", "sequence_length", self.head_size], - ) - - return cos_final_output, sin_final_output + return cos_output, sin_output def rotate_half(self, x_name, x_shape, basename, compute_dtype): # Make nodes for rotate_half subgraph From f97c2995921337699b5a389fe88a3b7ee1ac0754 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 24 Nov 2025 20:04:43 +0000 Subject: [PATCH 14/18] simplify mrope subgraph --- src/python/py/models/builders/qwen.py | 557 ++++++++++---------------- 1 file changed, 206 insertions(+), 351 deletions(-) diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index 5f841d0496..8e67948e05 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -56,7 +56,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): # Check rope type since huggingface model supports yarn but that is not recommended as mentioned in model card. Example: # "rope_scaling": {"type": "mrope", "mrope_section": [16, 24,24]} if config.rope_scaling and "type" in config.rope_scaling: - assert config.rope_scaling["type"] == "mrope" + assert config.rope_scaling["type"] in ["mrope", "default"] # Qwen 2.5 VL applies RoPE manually before attention, not fused in the op self.attention_attrs["use_rope_in_attn"] = False @@ -278,384 +278,239 @@ def make_dynamic_rope_caches(self, layer_id, basename): return cos_output, sin_output - def rotate_half(self, x_name, x_shape, basename, compute_dtype): - # Make nodes for rotate_half subgraph - # - # x (B, N, S, H) - # | - # Split - # / \ - # / \ - # x1 (..., H/2) x2 (..., H/2) - # | | - # | Neg - # | | - # | -x2 - # \ / - # \ / - # Concat - # | - # output (..., H) - - # Split: [B, N, S, H] -> [B, N, S, H/2], [B, N, S, H/2] - split_name = f"{basename}/rotate_half/Split" - split_output_0 = f"{split_name}/output_0" - split_output_1 = f"{split_name}/output_1" - self.make_node( - "Split", - [x_name], - [split_output_0, split_output_1], - name=split_name, - axis=-1, - num_outputs=2, - ) - half_shape = [*x_shape[:-1], x_shape[-1] // 2] - self.make_value(split_output_0, compute_dtype, half_shape) - self.make_value(split_output_1, compute_dtype, half_shape) - - # Negate x2 - neg_name = f"{basename}/rotate_half/Neg" - neg_output = f"{neg_name}/output_0" - self.make_node("Neg", [split_output_1], [neg_output], name=neg_name) - self.make_value(neg_output, compute_dtype, half_shape) - - # Concat (-x2, x1) - concat_name = f"{basename}/rotate_half/Concat" - concat_output = f"{concat_name}/output_0" - self.make_concat(concat_name, [neg_output, split_output_0], compute_dtype, x_shape, axis=-1) - - return concat_output - - def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn_sin, num_heads, basename): - # Make nodes for the MRoPE rotation subgraph - # - # Re-implements apply_multimodal_rotary_pos_emb using ONNX ops. - # Takes Q/K tensor and the dynamically generated 3D caches - # and applies the rotation. - # - # dyn_cos (3, B, S, H) dyn_sin (3, B, S, H) q_or_k (B, S, N*H) - # | | | - # Split Split Reshape - # (into 6 parts) (into 6 parts) | - # | | Transpose - # +-------+-------+ +-------+-------+ (B, N, S, H) - # | ...loop... | | ...loop... | | - # Gather(dim_idx) | Gather(dim_idx) | | - # | | | | | - # Unsqueeze | Unsqueeze | | - # | | | | | - # +-------+-------+ +-------+-------+ | - # | | | - # Concat Concat | - # (B, 1, S, H) (B, 1, S, H) | - # | | | - # +-----------+----------+ | - # | | - # (Mixed Precision Casts) | - # | | - # +-----------------------+-----------+ - # | - # (q * cos) + (rotate_half(q) * sin) - # | - # Transpose - # | - # Reshape - # - - # --- Handle precision for RoPE --- - # Check if we need to force float32 computation - force_fp32 = self.attention_attrs.get("rope_cast", {}).get("use_fp32", False) - - # Set compute_dtype (precision for math) and output_dtype (final precision) - compute_dtype = ir.DataType.FLOAT if force_fp32 else self.io_dtype - output_dtype = self.io_dtype - # -------------------------------- - - # Create a Constant node for mrope_splits - # This holds the correct splits, e.g., [16, 24, 24, 16, 24, 24] - mrope_splits_node_name = f"{basename}/mrope_splits/Constant" - mrope_splits_output_name = f"{basename}/mrope_splits" - mrope_splits_tensor = ir.tensor( - torch.tensor(self.mrope_splits, dtype=torch.int64), - name=mrope_splits_output_name, - ) - self.make_node( - "Constant", - inputs=[], - outputs=[mrope_splits_output_name], - name=mrope_splits_node_name, - value=mrope_splits_tensor, - ) - self.make_value(mrope_splits_output_name, ir.DataType.INT64, [len(self.mrope_splits)]) - - # Split the dynamic caches [3, B, S, H] into 6 chunks on axis -1 - # Caches (dyn_cos, dyn_sin) are already in float32 - num_splits = len(self.mrope_splits) - - cos_split_name = f"{basename}/cos/Split" - cos_split_outputs = [f"{cos_split_name}/output_{i}" for i in range(num_splits)] - self.make_node( - "Split", - [dyn_cos, mrope_splits_output_name], - cos_split_outputs, - name=cos_split_name, - axis=-1, - ) - - sin_split_name = f"{basename}/sin/Split" - sin_split_outputs = [f"{sin_split_name}/output_{i}" for i in range(num_splits)] - self.make_node( - "Split", - [dyn_sin, mrope_splits_output_name], - sin_split_outputs, - name=sin_split_name, - axis=-1, - ) + def make_mrope_flattened_caches(self, layer_id, dyn_cos, dyn_sin): + """ + Converts the 3D MRoPE caches [3, B, S, H] into flattened, interleaved caches [B*S, H/2] + suitable for the RotaryEmbedding operator. + The logic is: + 1. Slice dynamic caches to H/2. + 2. Split into 3 chunks based on mrope_sections (e.g. 16, 24, 24). + 3. Gather Temporal(0), Height(1), Width(2) specific slices for each chunk. + 4. Concat back to H/2. + 5. Flatten to [B*S, H/2]. + """ + basename = f"/model/layers.{layer_id}/attn/mrope_flattened_cache" + + def process_cache(input_name, name_suffix): + # 1. Slice to H/2: [3, B, S, H] -> [3, B, S, H/2] + slice_name = f"{basename}/{name_suffix}/Slice_Half" + slice_output = f"{slice_name}/output_0" + self.make_slice( + slice_name, + [ + input_name, + "/model/constants/INT64/[0]", + f"/model/constants/INT64/[{self.head_size // 2}]", + "/model/constants/INT64/[-1]", + ], + ir.DataType.FLOAT, + [3, "batch_size", "sequence_length", self.head_size // 2], + ) - # Re-order the caches: [T, H, W, T, H, W] - cos_reordered = [] - sin_reordered = [] - for i in range(num_splits): - dim_chunk = self.mrope_splits[i] - cache_dim_to_use = i % 3 # 0 for T, 1 for H, 2 for W - - # Gather from dim 0 of the split cache chunk - # input is [3, B, S, H_chunk], indices is [0, 1, or 2] - gather_cos_name = f"{basename}/cos_{i}/Gather" - gather_cos_output = f"{gather_cos_name}/output_0" + # Create a Constant node for mrope_sections: [16, 24, 24] + sections_name = f"{basename}/mrope_sections/Constant" + sections_output = f"{basename}/mrope_sections" self.make_node( - "Gather", - [cos_split_outputs[i], f"/model/constants/INT64/[{cache_dim_to_use}]"], - [gather_cos_output], - name=gather_cos_name, - axis=0, + "Constant", + [], + [sections_output], + name=sections_name, + value=ir.tensor(torch.tensor(self.mrope_sections, dtype=torch.int64), name=sections_output), ) - self.make_value( - gather_cos_output, - ir.DataType.FLOAT, - [1, "batch_size", "sequence_length", dim_chunk], - ) # Shape [1, B, S, H_chunk] + self.make_value(sections_output, ir.DataType.INT64, [3]) - gather_sin_name = f"{basename}/sin_{i}/Gather" - gather_sin_output = f"{gather_sin_name}/output_0" + # 2. Split: [3, B, S, H/2] -> 3 * [3, B, S, section_dim] + split_name = f"{basename}/{name_suffix}/Split" + split_outputs = [f"{split_name}/output_{i}" for i in range(3)] self.make_node( - "Gather", - [sin_split_outputs[i], f"/model/constants/INT64/[{cache_dim_to_use}]"], - [gather_sin_output], - name=gather_sin_name, - axis=0, - ) - self.make_value( - gather_sin_output, - ir.DataType.FLOAT, - [1, "batch_size", "sequence_length", dim_chunk], - ) # Shape [1, B, S, H_chunk] - - # FIX: Squeeze the gathered cache to [B, S, H_chunk] - squeeze_cos_name = f"{basename}/cos_{i}/Squeeze" - squeeze_cos_output = f"{squeeze_cos_name}/output_0" - self.make_squeeze( - squeeze_cos_name, - [gather_cos_output, "/model/constants/INT64/[0]"], - ir.DataType.FLOAT, - ["batch_size", "sequence_length", dim_chunk], + "Split", + [slice_output, sections_output], + split_outputs, + name=split_name, + axis=-1, ) - squeeze_sin_name = f"{basename}/sin_{i}/Squeeze" - squeeze_sin_output = f"{squeeze_sin_name}/output_0" - self.make_squeeze( - squeeze_sin_name, - [gather_sin_output, "/model/constants/INT64/[0]"], + # 3. Gather + Squeeze: Reorder T, H, W + gathered_chunks = [] + for i in range(3): + # Chunk 0->T(0), Chunk 1->H(1), Chunk 2->W(2) + gather_name = f"{basename}/{name_suffix}/chunk_{i}/Gather" + gather_output = f"{gather_name}/output_0" + self.make_node( + "Gather", + [split_outputs[i], f"/model/constants/INT64/[{i}]"], + [gather_output], + name=gather_name, + axis=0, + ) + # Gather output is [1, B, S, dim] + + squeeze_name = f"{basename}/{name_suffix}/chunk_{i}/Squeeze" + squeeze_output = f"{squeeze_name}/output_0" + self.make_squeeze( + squeeze_name, + [gather_output, "/model/constants/INT64/[0]"], + ir.DataType.FLOAT, + ["batch_size", "sequence_length", self.mrope_sections[i]], + ) + gathered_chunks.append(squeeze_output) + + # 4. Concat: -> [B, S, H/2] + concat_name = f"{basename}/{name_suffix}/Concat" + concat_output = f"{concat_name}/output_0" + self.make_concat( + concat_name, + gathered_chunks, ir.DataType.FLOAT, - ["batch_size", "sequence_length", dim_chunk], + ["batch_size", "sequence_length", self.head_size // 2], + axis=-1, ) - # Unsqueeze to add the NumHeads dim: [B, 1, S, H_chunk] - unsqueeze_cos_name = f"{basename}/cos_{i}/Unsqueeze" - unsqueeze_cos_output = f"{unsqueeze_cos_name}/output_0" - self.make_unsqueeze( - unsqueeze_cos_name, - [squeeze_cos_output, "/model/constants/INT64/[1]"], + # 5. Flatten: -> [B*S, H/2] + reshape_name = f"{basename}/{name_suffix}_flat/Reshape" + reshape_output = f"{reshape_name}/output_0" + self.make_reshape( + reshape_name, + [concat_output, f"/model/constants/INT64/[-1, {self.head_size // 2}]"], ir.DataType.FLOAT, - ["batch_size", 1, "sequence_length", dim_chunk], + ["total_token_count", self.head_size // 2], ) - cos_reordered.append(unsqueeze_cos_output) + return reshape_output - unsqueeze_sin_name = f"{basename}/sin_{i}/Unsqueeze" - unsqueeze_sin_output = f"{unsqueeze_sin_name}/output_0" - self.make_unsqueeze( - unsqueeze_sin_name, - [squeeze_sin_output, "/model/constants/INT64/[1]"], - ir.DataType.FLOAT, - ["batch_size", 1, "sequence_length", dim_chunk], - ) - sin_reordered.append(unsqueeze_sin_output) + flat_cos = process_cache(dyn_cos, "cos") + flat_sin = process_cache(dyn_sin, "sin") - # Concat re-ordered chunks back to [B, 1, S, H] - final_cos_concat_name = f"{basename}/cos_final/Concat" - final_cos_concat_output = f"{final_cos_concat_name}/output_0" - self.make_concat( - final_cos_concat_name, - cos_reordered, - ir.DataType.FLOAT, - ["batch_size", 1, "sequence_length", self.head_size], - axis=-1, - ) - - final_sin_concat_name = f"{basename}/sin_final/Concat" - final_sin_concat_output = f"{final_sin_concat_name}/output_0" - self.make_concat( - final_sin_concat_name, - sin_reordered, - ir.DataType.FLOAT, - ["batch_size", 1, "sequence_length", self.head_size], - axis=-1, - ) + return flat_cos, flat_sin - # Caches (final_cos_concat_output, final_sin_concat_output) are now in float32 - - # Reshape input Q/K: [B, S, N*H] -> [B, S, N, H] - reshape_1_name = f"{basename}/q_or_k_bsd_to_bsnh/Reshape" - reshape_1_output = f"{reshape_1_name}/output_0" - reshape_1_target_shape_onnx = f"/model/constants/INT64/[0, 0, {num_heads}, {self.head_size}]" - reshape_1_target_shape_ort = [ - "batch_size", - "sequence_length", - num_heads, - self.head_size, - ] + def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn_sin, num_heads, basename): + # Use the optimized RotaryEmbedding operator with a single node. + # + # 1. Prepare flattened MRoPE caches [B*S, H/2] + # This slices, splits, and re-assembles the 3D dynamic caches into the correct per-token layout. + flat_cos, flat_sin = self.make_mrope_flattened_caches(layer_id, dyn_cos, dyn_sin) + + # 2. Prepare position_ids [B, S] (values 0 to B*S - 1) + # RotaryEmbedding will use these indices to access the flattened cache. + # Get B*S from q_or_k shape. q_or_k is [B, S, N*H]. + shape_node = f"{basename}/Shape" + self.make_shape(shape_node, q_or_k_path, [3]) + + # Extract B and S + batch_size_node = f"{basename}/BatchSize/Gather" + batch_size_out = f"{batch_size_node}/output_0" + self.make_gather(batch_size_node, [f"{shape_node}/output_0", "/model/constants/INT64/[0]"], ir.DataType.INT64, + [], 0) + + seq_len_node = f"{basename}/SeqLen/Gather" + seq_len_out = f"{seq_len_node}/output_0" + self.make_gather(seq_len_node, [f"{shape_node}/output_0", "/model/constants/INT64/[1]"], ir.DataType.INT64, [], + 0) + + # Calculate Total Tokens = B * S + mul_len_node = f"{basename}/TotalLen/Mul" + mul_len_out = f"{mul_len_node}/output_0" + self.make_node("Mul", [batch_size_out, seq_len_out], [mul_len_out], name=mul_len_node) + self.make_value(mul_len_out, ir.DataType.INT64, []) + + # Range(0, TotalTokens) + range_node = f"{basename}/Range" + range_out = f"{range_node}/output_0" + self.make_node("Range", ["/model/constants/INT64/0", mul_len_out, "/model/constants/INT64/1"], [range_out], + name=range_node) + self.make_value(range_out, ir.DataType.INT64, ["total_token_count"]) + + # Slice Position IDs shape from input shape (take first 2 dims) + slice_shape_node = f"{basename}/SliceShape" + slice_shape_out = f"{slice_shape_node}/output_0" + self.make_slice(slice_shape_node, + [f"{shape_node}/output_0", "/model/constants/INT64/[0]", "/model/constants/INT64/[2]", + "/model/constants/INT64/[0]"], ir.DataType.INT64, [2]) + + # Reshape Range output to [B, S] + pos_ids_reshape_node = f"{basename}/PosIds/Reshape" + pos_ids_out = f"{pos_ids_reshape_node}/output_0" + self.make_reshape(pos_ids_reshape_node, [range_out, slice_shape_out], ir.DataType.INT64, + ["batch_size", "sequence_length"]) + + # 3. Prepare Q/K input [B, N, S, H] + # Input is [B, S, N*H]. Reshape -> [B, S, N, H] -> Transpose -> [B, N, S, H] + reshape_in_node = f"{basename}/Input/Reshape" + reshape_in_out = f"{reshape_in_node}/output_0" self.make_reshape( - reshape_1_name, - [q_or_k_path, reshape_1_target_shape_onnx], + reshape_in_node, + [q_or_k_path, f"/model/constants/INT64/[0, 0, {num_heads}, {self.head_size}]"], self.io_dtype, - reshape_1_target_shape_ort, + ["batch_size", "sequence_length", num_heads, self.head_size], ) - # Transpose Q/K: [B, S, N, H] -> [B, N, S, H] - transpose_1_name = f"{basename}/q_or_k_bsnh_to_bnsh/Transpose" - transpose_1_output = f"{transpose_1_name}/output_0" - transpose_1_target_shape = [ - "batch_size", - num_heads, - "sequence_length", - self.head_size, - ] - self.make_transpose( - transpose_1_name, - reshape_1_output, - self.io_dtype, - transpose_1_target_shape, - perm=[0, 2, 1, 3], - ) + transpose_in_node = f"{basename}/Input/Transpose" + transpose_in_out = f"{transpose_in_node}/output_0" + target_shape_bnsh = ["batch_size", num_heads, "sequence_length", self.head_size] + self.make_transpose(transpose_in_node, reshape_in_out, self.io_dtype, target_shape_bnsh, [0, 2, 1, 3]) - # --- Start RoPE computation --- - q_or_k_compute_input = transpose_1_output - cos_cache_compute_input = final_cos_concat_output - sin_cache_compute_input = final_sin_concat_output + # 4. Handle Type Casting + # RotaryEmbedding requires input, cos, sin to be same type. + # Qwen2.5-VL forces float32 computation. + force_fp32 = self.attention_attrs.get("rope_cast", {}).get("use_fp32", False) + compute_dtype = ir.DataType.FLOAT if force_fp32 else self.io_dtype + rope_input = transpose_in_out if force_fp32 and self.io_dtype != ir.DataType.FLOAT: - # Cast Q/K (self.io_dtype) up to float32 - q_or_k_cast_name = f"{basename}/q_or_k/Cast" - q_or_k_cast_output = f"{q_or_k_cast_name}/output_0" - self.make_cast( - q_or_k_cast_name, - transpose_1_output, - compute_dtype, - transpose_1_target_shape, - ) - q_or_k_compute_input = q_or_k_cast_output - elif not force_fp32 and self.io_dtype != ir.DataType.FLOAT: - # Cast Caches (float32) down to self.io_dtype - cos_cache_cast_name = f"{basename}/cos_final/Cast" - cos_cache_cast_output = f"{cos_cache_cast_name}/output_0" - self.make_cast( - cos_cache_cast_name, - final_cos_concat_output, - compute_dtype, - ["batch_size", 1, "sequence_length", self.head_size], - ) - cos_cache_compute_input = cos_cache_cast_output - - sin_cache_cast_name = f"{basename}/sin_final/Cast" - sin_cache_cast_output = f"{sin_cache_cast_name}/output_0" - self.make_cast( - sin_cache_cast_name, - final_sin_concat_output, - compute_dtype, - ["batch_size", 1, "sequence_length", self.head_size], - ) - sin_cache_compute_input = sin_cache_cast_output - - # Apply rotation: (q * cos) + (rotate_half(q) * sin) - - # 1. (q * cos) - mul_1_name = f"{basename}/Mul_1" - mul_1_output = f"{mul_1_name}/output_0" - self.make_mul( - mul_1_name, - [q_or_k_compute_input, cos_cache_compute_input], - compute_dtype, - transpose_1_target_shape, - ) - - # 2. rotate_half(q) - rotated_half_q_name = self.rotate_half(q_or_k_compute_input, transpose_1_target_shape, basename, compute_dtype) - - # 3. (rotate_half(q) * sin) - mul_2_name = f"{basename}/Mul_2" - mul_2_output = f"{mul_2_name}/output_0" - self.make_mul( - mul_2_name, - [rotated_half_q_name, sin_cache_compute_input], - compute_dtype, - transpose_1_target_shape, - ) - - # 4. (q * cos) + (rotate_half(q) * sin) - add_name = f"{basename}/add/Add" - add_output = f"{add_name}/output_0" - self.make_add( - add_name, - [mul_1_output, mul_2_output], - compute_dtype, - transpose_1_target_shape, + cast_in_node = f"{basename}/Input/Cast" + rope_input = f"{cast_in_node}/output_0" + self.make_cast(cast_in_node, transpose_in_out, compute_dtype, target_shape_bnsh) + + rope_cos = flat_cos + rope_sin = flat_sin + # Note: dyn_cos is Float. flat_cos is Float. If compute_dtype is not Float (e.g. fp16), we must cast cache. + if compute_dtype != ir.DataType.FLOAT: + # Cache is Float, we need FP16 + cast_cos_node = f"{basename}/Cos/Cast" + rope_cos = f"{cast_cos_node}/output_0" + self.make_cast(cast_cos_node, flat_cos, compute_dtype, ["total_token_count", self.head_size // 2]) + + cast_sin_node = f"{basename}/Sin/Cast" + rope_sin = f"{cast_sin_node}/output_0" + self.make_cast(cast_sin_node, flat_sin, compute_dtype, ["total_token_count", self.head_size // 2]) + + # 5. RotaryEmbedding Node + rope_node = f"{basename}/RotaryEmbedding" + rope_output = f"{rope_node}/output_0" + self.make_node( + "RotaryEmbedding", + [rope_input, pos_ids_out, rope_cos, rope_sin], + [rope_output], + name=rope_node, + domain="com.microsoft", + rotary_embedding_dim=self.head_size, + num_heads=num_heads, + interleaved=0, # False, matches rotate_half logic ) + self.make_value(rope_output, compute_dtype, target_shape_bnsh) - # --- End RoPE computation --- - - add_output_final = add_output + # 6. Post-process Output + # Cast back if needed -> Transpose -> Reshape + final_rope_output = rope_output if force_fp32 and self.io_dtype != ir.DataType.FLOAT: - # Cast result back down to self.io_dtype - add_cast_name = f"{basename}/add/Cast" - add_cast_output = f"{add_cast_name}/output_0" - self.make_cast(add_cast_name, add_output, output_dtype, transpose_1_target_shape) - add_output_final = add_cast_output - - # Transpose back: [B, N, S, H] -> [B, S, N, H] - transpose_2_name = f"{basename}/q_or_k_bnsh_to_bsnh/Transpose" - transpose_2_output = f"{transpose_2_name}/output_0" - self.make_transpose( - transpose_2_name, - add_output_final, - output_dtype, - reshape_1_target_shape_ort, - perm=[0, 2, 1, 3], - ) + cast_out_node = f"{basename}/Output/Cast" + final_rope_output = f"{cast_out_node}/output_0" + self.make_cast(cast_out_node, rope_output, self.io_dtype, target_shape_bnsh) - # Reshape back: [B, S, N, H] -> [B, S, N*H] - reshape_2_name = f"{basename}/q_or_k_bsnh_to_bsd/Reshape" - reshape_2_output = f"{reshape_2_name}/output_0" + transpose_out_node = f"{basename}/Output/Transpose" + transpose_out_out = f"{transpose_out_node}/output_0" + self.make_transpose(transpose_out_node, final_rope_output, self.io_dtype, + ["batch_size", "sequence_length", num_heads, self.head_size], [0, 2, 1, 3]) + + reshape_out_node = f"{basename}/Output/Reshape" + reshape_out_out = f"{reshape_out_node}/output_0" self.make_reshape( - reshape_2_name, - [ - transpose_2_output, - f"/model/constants/INT64/[0, 0, {num_heads * self.head_size}]", - ], - output_dtype, - q_or_k_shape, + reshape_out_node, + [transpose_out_out, f"/model/constants/INT64/[0, 0, {num_heads * self.head_size}]"], + self.io_dtype, + q_or_k_shape ) - return reshape_2_output + return reshape_out_out def make_attention(self, layer_id, attention, root_input, **kwargs): # Make nodes for the Attention subgraph (with MRoPE) @@ -867,4 +722,4 @@ def make_model(self, input_path): if not self.exclude_lm_head: # The LM head is part of the parent Qwen2_5_VLForConditionalGeneration model print("Reading LM head") - self.make_lm_head(hf_model.lm_head) + self.make_lm_head(hf_model.lm_head) \ No newline at end of file From 4ccd291e86725da1594ca66736f90308e09a25f9 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 24 Nov 2025 20:10:20 +0000 Subject: [PATCH 15/18] comment --- src/python/py/models/builders/qwen.py | 81 +++++++++++++++++++++------ 1 file changed, 65 insertions(+), 16 deletions(-) diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index 8e67948e05..f08f2b334d 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -383,8 +383,40 @@ def process_cache(input_name, name_suffix): return flat_cos, flat_sin def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn_sin, num_heads, basename): - # Use the optimized RotaryEmbedding operator with a single node. + # Make nodes for the MRoPE rotation subgraph using RotaryEmbedding op # + # 1. Flatten 3D caches [3, B, S, H] -> [B*S, H/2] (via make_mrope_flattened_caches) + # 2. Generate linear position IDs [B, S] (0 .. B*S-1) + # 3. Apply RotaryEmbedding + # + # dyn_cos (3, B, S, H) dyn_sin (3, B, S, H) + # | | + # make_mrope_flattened_caches (slice, split, gather, concat, flatten) + # | | + # flat_cos flat_sin + # (B*S, H/2) (B*S, H/2) + # | | + # +-----------+----------+ + # | + # q_or_k | position_ids + # (B, S, N*H) | (0 .. B*S-1) + # | | | + # Reshape | Reshape + # | | | + # Transpose | | + # (B, N, S, H) | (B, S) + # | | | + # +--------+--------+--------+--------+ + # | | + # RotaryEmbedding (com.microsoft) + # | + # output (B, N, S, H) + # | + # Transpose + # | + # Reshape + # (B, S, N*H) + # 1. Prepare flattened MRoPE caches [B*S, H/2] # This slices, splits, and re-assembles the 3D dynamic caches into the correct per-token layout. flat_cos, flat_sin = self.make_mrope_flattened_caches(layer_id, dyn_cos, dyn_sin) @@ -398,13 +430,15 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn # Extract B and S batch_size_node = f"{basename}/BatchSize/Gather" batch_size_out = f"{batch_size_node}/output_0" - self.make_gather(batch_size_node, [f"{shape_node}/output_0", "/model/constants/INT64/[0]"], ir.DataType.INT64, - [], 0) + self.make_gather( + batch_size_node, [f"{shape_node}/output_0", "/model/constants/INT64/[0]"], ir.DataType.INT64, [], 0 + ) seq_len_node = f"{basename}/SeqLen/Gather" seq_len_out = f"{seq_len_node}/output_0" - self.make_gather(seq_len_node, [f"{shape_node}/output_0", "/model/constants/INT64/[1]"], ir.DataType.INT64, [], - 0) + self.make_gather( + seq_len_node, [f"{shape_node}/output_0", "/model/constants/INT64/[1]"], ir.DataType.INT64, [], 0 + ) # Calculate Total Tokens = B * S mul_len_node = f"{basename}/TotalLen/Mul" @@ -415,22 +449,32 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn # Range(0, TotalTokens) range_node = f"{basename}/Range" range_out = f"{range_node}/output_0" - self.make_node("Range", ["/model/constants/INT64/0", mul_len_out, "/model/constants/INT64/1"], [range_out], - name=range_node) + self.make_node( + "Range", ["/model/constants/INT64/0", mul_len_out, "/model/constants/INT64/1"], [range_out], name=range_node + ) self.make_value(range_out, ir.DataType.INT64, ["total_token_count"]) # Slice Position IDs shape from input shape (take first 2 dims) slice_shape_node = f"{basename}/SliceShape" slice_shape_out = f"{slice_shape_node}/output_0" - self.make_slice(slice_shape_node, - [f"{shape_node}/output_0", "/model/constants/INT64/[0]", "/model/constants/INT64/[2]", - "/model/constants/INT64/[0]"], ir.DataType.INT64, [2]) + self.make_slice( + slice_shape_node, + [ + f"{shape_node}/output_0", + "/model/constants/INT64/[0]", + "/model/constants/INT64/[2]", + "/model/constants/INT64/[0]", + ], + ir.DataType.INT64, + [2], + ) # Reshape Range output to [B, S] pos_ids_reshape_node = f"{basename}/PosIds/Reshape" pos_ids_out = f"{pos_ids_reshape_node}/output_0" - self.make_reshape(pos_ids_reshape_node, [range_out, slice_shape_out], ir.DataType.INT64, - ["batch_size", "sequence_length"]) + self.make_reshape( + pos_ids_reshape_node, [range_out, slice_shape_out], ir.DataType.INT64, ["batch_size", "sequence_length"] + ) # 3. Prepare Q/K input [B, N, S, H] # Input is [B, S, N*H]. Reshape -> [B, S, N, H] -> Transpose -> [B, N, S, H] @@ -498,8 +542,13 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn transpose_out_node = f"{basename}/Output/Transpose" transpose_out_out = f"{transpose_out_node}/output_0" - self.make_transpose(transpose_out_node, final_rope_output, self.io_dtype, - ["batch_size", "sequence_length", num_heads, self.head_size], [0, 2, 1, 3]) + self.make_transpose( + transpose_out_node, + final_rope_output, + self.io_dtype, + ["batch_size", "sequence_length", num_heads, self.head_size], + [0, 2, 1, 3], + ) reshape_out_node = f"{basename}/Output/Reshape" reshape_out_out = f"{reshape_out_node}/output_0" @@ -507,7 +556,7 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn reshape_out_node, [transpose_out_out, f"/model/constants/INT64/[0, 0, {num_heads * self.head_size}]"], self.io_dtype, - q_or_k_shape + q_or_k_shape, ) return reshape_out_out @@ -722,4 +771,4 @@ def make_model(self, input_path): if not self.exclude_lm_head: # The LM head is part of the parent Qwen2_5_VLForConditionalGeneration model print("Reading LM head") - self.make_lm_head(hf_model.lm_head) \ No newline at end of file + self.make_lm_head(hf_model.lm_head) From e04e4df355336ce43cf8435849ccaef9d2153dfb Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 26 Nov 2025 15:34:09 -0800 Subject: [PATCH 16/18] refactoring --- src/python/py/models/builders/base.py | 30 ++++-- src/python/py/models/builders/qwen.py | 144 ++++++++++---------------- 2 files changed, 74 insertions(+), 100 deletions(-) diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index 23b5252807..4db9319959 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -2783,7 +2783,11 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): # O_MatMul # | # O_Add + self.make_attention_input_proj(layer_id, attention, root_input, **kwargs) + self.make_attention_qk_subgraph(layer_id, attention, root_input, **kwargs) + self.make_attention_output_proj(layer_id, attention, root_input, **kwargs) + def make_attention_input_proj(self, layer_id, attention, root_input, **kwargs): # Unpack attention weights if needed self.make_attention_unpacked(layer_id, attention, root_input, **kwargs) @@ -2859,6 +2863,7 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): ) self.attention_attrs["v_path"] = f"{v_add_name}/output_0" + def make_attention_qk_subgraph(self, layer_id, attention, root_input, **kwargs): # Make Q/K SimplifiedLayerNorm nodes if self.attention_attrs["q_norm"] and self.attention_attrs["k_norm"]: self.make_qk_norm(layer_id, attention) @@ -2920,11 +2925,15 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): **kwargs, ) + def make_attention_output_proj(self, layer_id, attention, root_input, **kwargs): + attn_name = f"/model/layers.{layer_id}/attn/{self.attention_attrs['op_type']}" + attn_output = f"{attn_name}/output_0" + # Make MatMul node (output projection weight node) o_proj = "o_proj" if hasattr(attention, "o_proj") else "dense" o_matmul_basename = f"/model/layers.{layer_id}/attn/o_proj/MatMul" o_weight = getattr(attention, o_proj) - o_matmul_name = self.make_matmul(o_weight, o_matmul_basename, f"{attn_name}/output_0") + o_matmul_name = self.make_matmul(o_weight, o_matmul_basename, attn_output) # Make Add node (output projection bias node if bias exists) o_bias_exists = getattr(attention, o_proj).bias is not None @@ -3799,13 +3808,7 @@ def make_layer(self, layer_id, layer): # Norm after last decoder layer of model (last layer --> norm) self.layernorm_attrs["last_layernorm"] = True - def make_model(self, input_path): - # Make inputs and outputs to ONNX model - self.make_inputs_and_outputs() - - # Make pre-processing nodes - self.make_preprocessing_nodes() - + def load_weights(self, input_path): # Load weights of original model if input_path.endswith(".gguf"): # Load GGUF model @@ -3859,6 +3862,17 @@ def make_model(self, input_path): model = PeftModel.from_pretrained( model, self.extra_options["adapter_path"], cache_dir=self.cache_dir, token=self.hf_token ) + + return model + + def make_model(self, input_path): + # Make inputs and outputs to ONNX model + self.make_inputs_and_outputs() + + # Make pre-processing nodes + self.make_preprocessing_nodes() + + model = self.load_weights(input_path) # Loop through model and map each module to ONNX/ORT ops self.layer_id = 0 diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index f08f2b334d..c9116e71bb 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -279,16 +279,36 @@ def make_dynamic_rope_caches(self, layer_id, basename): return cos_output, sin_output def make_mrope_flattened_caches(self, layer_id, dyn_cos, dyn_sin): - """ - Converts the 3D MRoPE caches [3, B, S, H] into flattened, interleaved caches [B*S, H/2] - suitable for the RotaryEmbedding operator. - The logic is: - 1. Slice dynamic caches to H/2. - 2. Split into 3 chunks based on mrope_sections (e.g. 16, 24, 24). - 3. Gather Temporal(0), Height(1), Width(2) specific slices for each chunk. - 4. Concat back to H/2. - 5. Flatten to [B*S, H/2]. - """ + # Converts the 3D MRoPE caches [3, B, S, H] into flattened, interleaved caches [B*S, H/2] + # suitable for the RotaryEmbedding operator. + # The logic is: + # 1. Slice dynamic caches to H/2. + # 2. Split into 3 chunks based on mrope_sections (e.g. 16, 24, 24). + # 3. Gather Temporal(0), Height(1), Width(2) specific slices for each chunk. + # 4. Concat back to H/2. + # 5. Flatten to [B*S, H/2]. + # The subgraph looks like: + # dyn_cos (3, B, S, H) + # | + # Slice_Half + # (3, B, S, H/2) + # | + # Split + # (3, B, S, sections[i]) + # / | \ + # Gather Gather Gather + # idx=0 idx=1 idx=2 + # / | \ + # Squeeze Squeeze Squeeze + # \ | / + # \ | / + # \ | / + # Concat + # (B, S, H/2) + # | + # Reshape + # (B*S, H/2) + basename = f"/model/layers.{layer_id}/attn/mrope_flattened_cache" def process_cache(input_name, name_suffix): @@ -561,15 +581,10 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn return reshape_out_out - def make_attention(self, layer_id, attention, root_input, **kwargs): + def make_attention_qk_subgraph(self, layer_id, attention, root_input, **kwargs): # Make nodes for the Attention subgraph (with MRoPE) # - # root_input - # / | \ - # / | \ - # Q_MatMul K_MatMul V_MatMul - # | | | - # Q_Add K_Add V_Add + # q_path k_path v_path # | | | # | | +-----------------+ # | | | @@ -591,63 +606,20 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): # | | # GroupQueryAttention <--------------+ # | - # O_MatMul - # | - # O_Add - - # 1. Unpack QKV if necessary (e.g. qkv_proj) - super().make_attention_unpacked(layer_id, attention, root_input, **kwargs) - - # 2. Build Q/K/V MatMul and Add nodes - q_matmul_basename = f"/model/layers.{layer_id}/attn/q_proj/MatMul" - q_matmul_name = self.make_matmul(attention.q_proj, q_matmul_basename, root_input) - self.attention_attrs["q_path"] = f"{q_matmul_name}/output_0" + + # 1. Calculate shapes for MRoPE rotation q_shape = [ "batch_size", "sequence_length", self.num_attn_heads * self.head_size, ] + k_shape = [ + "batch_size", + "sequence_length", + self.num_kv_heads * self.head_size, + ] - k_matmul_basename = f"/model/layers.{layer_id}/attn/k_proj/MatMul" - k_matmul_name = self.make_matmul(attention.k_proj, k_matmul_basename, root_input) - self.attention_attrs["k_path"] = f"{k_matmul_name}/output_0" - k_shape = ["batch_size", "sequence_length", self.num_kv_heads * self.head_size] - - v_matmul_basename = f"/model/layers.{layer_id}/attn/v_proj/MatMul" - v_matmul_name = self.make_matmul(attention.v_proj, v_matmul_basename, root_input) - self.attention_attrs["v_path"] = f"{v_matmul_name}/output_0" - - # Handle biases - q_bias_exists = attention.q_proj.bias is not None and torch.count_nonzero(attention.q_proj.bias) > 0 - k_bias_exists = attention.k_proj.bias is not None and torch.count_nonzero(attention.k_proj.bias) > 0 - v_bias_exists = attention.v_proj.bias is not None and torch.count_nonzero(attention.v_proj.bias) > 0 - - if q_bias_exists: - q_add_name = f"/model/layers.{layer_id}/attn/q_proj/Add" - self.make_add_bias( - attention.q_proj.bias, - q_add_name, - root_input=self.attention_attrs["q_path"], - ) - self.attention_attrs["q_path"] = f"{q_add_name}/output_0" - if k_bias_exists: - k_add_name = f"/model/layers.{layer_id}/attn/k_proj/Add" - self.make_add_bias( - attention.k_proj.bias, - k_add_name, - root_input=self.attention_attrs["k_path"], - ) - self.attention_attrs["k_path"] = f"{k_add_name}/output_0" - if v_bias_exists: - v_add_name = f"/model/layers.{layer_id}/attn/v_proj/Add" - self.make_add_bias( - attention.v_proj.bias, - v_add_name, - root_input=self.attention_attrs["v_path"], - ) - self.attention_attrs["v_path"] = f"{v_add_name}/output_0" - - # 3. Apply 3D RoPE (MRoPE) + # 2. Apply 3D RoPE (MRoPE) cos_dynamic, sin_dynamic = self.make_dynamic_rope_caches( layer_id, basename=f"/model/layers.{layer_id}/attn/mrope_dynamic_cache" ) @@ -674,7 +646,7 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): basename=f"/model/layers.{layer_id}/attn/k_mrope", ) - # 4. Call GroupQueryAttention op + # 3. Call GroupQueryAttention op past_k = f"past_key_values.{layer_id}.key" past_v = f"past_key_values.{layer_id}.value" present_k = f"present.{layer_id}.key" @@ -696,28 +668,7 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): **kwargs, ) - # 5. Build O-proj - o_proj = "o_proj" if hasattr(attention, "o_proj") else "dense" - o_matmul_basename = f"/model/layers.{layer_id}/attn/o_proj/MatMul" - o_weight = getattr(attention, o_proj) - o_matmul_name = self.make_matmul(o_weight, o_matmul_basename, f"{attn_name}/output_0") - - o_bias_exists = getattr(attention, o_proj).bias is not None - if o_bias_exists: - o_add_name = f"/model/layers.{layer_id}/attn/o_proj/Add" - o_bias = getattr(attention, o_proj).bias - self.make_add_bias(o_bias, o_add_name, root_input=f"{o_matmul_name}/output_0") - self.layernorm_attrs["skip_input"] = f"{o_add_name}/output_0" - else: - self.layernorm_attrs["skip_input"] = f"{o_matmul_name}/output_0" - - def make_model(self, input_path): - # Make inputs and outputs to ONNX model - self.make_inputs_and_outputs() - - # Make pre-processing nodes - self.make_preprocessing_nodes() - + def load_weights(self, input_path): # Load the Hugging Face model print("Loading Qwen2_5_VLForConditionalGeneration model...") hf_model = Qwen2_5_VLForConditionalGeneration.from_pretrained( @@ -728,10 +679,19 @@ def make_model(self, input_path): trust_remote_code=self.hf_remote, ) + def make_model(self, input_path): + # Make inputs and outputs to ONNX model + self.make_inputs_and_outputs() + + # Make pre-processing nodes + self.make_preprocessing_nodes() + + hf_model = self.load_weights(input_path) + # We only want to export the text model model = hf_model.language_model print(f"Isolated language_model ({model.__class__.__name__}) for ONNX export.") - + # Loop through model and map each module to ONNX/ORT ops self.layer_id = 0 From 88964ad6043b76eb22c351ef574ee2ea732203e2 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 1 Dec 2025 22:42:58 +0000 Subject: [PATCH 17/18] review feedback --- src/python/py/models/builders/base.py | 269 ++++---------------------- src/python/py/models/builders/qwen.py | 26 +-- 2 files changed, 56 insertions(+), 239 deletions(-) diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index 4db9319959..c2ccdc9640 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -640,12 +640,7 @@ def make_key_value_cache_shape(self, layer_id, shape): For TensorRT EP with sliding window layers, replaces 'sequence' with 'sliding' in dimension name. """ if self.ep == "trt-rtx" and hasattr(self, "is_local") and self.is_local(layer_id): - return [ - shape[0], - shape[1], - shape[2].replace("sequence", "sliding"), - shape[3], - ] + return [shape[0], shape[1], shape[2].replace("sequence", "sliding"), shape[3]] return shape def save_processing(self, model_name_or_path, extra_kwargs, out_dir): @@ -770,16 +765,7 @@ def tensor_func(): value.const_value = ir_tensor self.model.graph.register_initializer(value) - def make_node( - self, - op_type, - inputs: Sequence[str], - outputs: Sequence[str], - *, - name: str, - domain="", - **kwargs, - ): + def make_node(self, op_type, inputs: Sequence[str], outputs: Sequence[str], *, name: str, domain="", **kwargs): assert name, "Node name must be provided" if name in self.node_names: # Note: @@ -801,14 +787,7 @@ def make_node( # Resolve values from names input_values = [self.make_value(name) for name in inputs] output_values = [self.make_value(name) for name in outputs] - node = ir.node( - op_type, - inputs=input_values, - attributes=kwargs, - domain=domain, - outputs=output_values, - name=name, - ) + node = ir.node(op_type, inputs=input_values, attributes=kwargs, domain=domain, outputs=output_values, name=name) self.model.graph.append(node) self.node_names.add(name) @@ -854,13 +833,7 @@ def make_inputs_and_outputs(self): # Add KV cache to inputs key_name = f"past_key_values.{i}.key" key_shape = self.make_key_value_cache_shape(i, self.input_shapes["past_key_values.key"]) - inputs.append( - self.make_value( - key_name, - dtype=self.input_types["past_key_values.key"], - shape=key_shape, - ) - ) + inputs.append(self.make_value(key_name, dtype=self.input_types["past_key_values.key"], shape=key_shape)) value_name = f"past_key_values.{i}.value" value_shape = self.make_key_value_cache_shape(i, self.input_shapes["past_key_values.value"]) @@ -875,13 +848,7 @@ def make_inputs_and_outputs(self): value_name = f"present.{i}.value" value_shape = self.make_key_value_cache_shape(i, self.output_shapes["present.value"]) - outputs.append( - self.make_value( - value_name, - dtype=self.output_types["present.value"], - shape=value_shape, - ) - ) + outputs.append(self.make_value(value_name, dtype=self.output_types["present.value"], shape=value_shape)) def make_constant(self, name): # Make constant ops for 0, 1, 2, 3, etc. @@ -914,13 +881,7 @@ def make_shape(self, name, root_input, shape): def make_constant_of_shape(self, name, root_input, value, dtype, shape): output = f"{name}/output_0" - self.make_node( - "ConstantOfShape", - inputs=[root_input], - outputs=[output], - name=name, - value=value, - ) + self.make_node("ConstantOfShape", inputs=[root_input], outputs=[output], name=name, value=value) self.make_value(output, dtype, shape=shape) def make_unsqueeze(self, name, inputs, dtype, shape): @@ -1067,11 +1028,7 @@ def make_matmul(self, matmul, basename, root_input, **kwargs): return self.make_matmul_op(matmul, basename, root_input, **kwargs) def make_matmul_op(self, matmul, basename, root_input, **kwargs): - if self.onnx_dtype in { - ir.DataType.FLOAT16, - ir.DataType.BFLOAT16, - ir.DataType.FLOAT, - }: + if self.onnx_dtype in {ir.DataType.FLOAT16, ir.DataType.BFLOAT16, ir.DataType.FLOAT}: return self.make_matmul_float(matmul, basename, root_input, **kwargs) elif self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4}: if self.quant_attrs["use_qdq"]: @@ -1256,11 +1213,7 @@ def make_matmul_lora(self, matmul, basename, root_input, **kwargs): return add_name def make_packed_matmul(self, q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs): - if self.onnx_dtype in { - ir.DataType.FLOAT, - ir.DataType.FLOAT16, - ir.DataType.BFLOAT16, - }: + if self.onnx_dtype in {ir.DataType.FLOAT, ir.DataType.FLOAT16, ir.DataType.BFLOAT16}: return self.make_packed_matmul_float(q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs) elif self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4}: return self.make_packed_matmul_int4(q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs) @@ -1712,10 +1665,7 @@ def make_rotary_embedding_caches_from_scratch(self): freqs = torch.outer(t, inv_freq) emb = torch.cat((freqs, freqs), dim=-1) - cos_cache, sin_cache = ( - emb.cos() * self.rope_attrs["mscale"], - emb.sin() * self.rope_attrs["mscale"], - ) + cos_cache, sin_cache = emb.cos() * self.rope_attrs["mscale"], emb.sin() * self.rope_attrs["mscale"] return cos_cache, sin_cache def make_rotary_embedding_caches(self, **kwargs): @@ -1908,12 +1858,7 @@ def make_rotary_embedding(self, name, root_input, **kwargs): cos_cache_name, sin_cache_name = self.make_rotary_embedding_caches() num_heads = self.num_kv_heads if "k_rotary" in name else self.num_attn_heads - inputs = [ - root_input, - kwargs.pop("position_ids"), - cos_cache_name, - sin_cache_name, - ] + inputs = [root_input, kwargs.pop("position_ids"), cos_cache_name, sin_cache_name] output = f"{name}/output_0" self.make_node( "RotaryEmbedding", @@ -1939,10 +1884,7 @@ def make_rotary_embedding_multi_cache(self, **kwargs): self.rope_attrs["mscale"] = self.rope_attrs["multi_cache"]["long_mscale"] # Create caches for when sequence_length > self.original_context_length - cos_cache_large_name, sin_cache_large_name = ( - "cos_cache_large", - "sin_cache_large", - ) + cos_cache_large_name, sin_cache_large_name = "cos_cache_large", "sin_cache_large" self.rope_attrs["save_caches"] = False cos_cache_large, sin_cache_large = self.make_rotary_embedding_caches( cos_cache_name=cos_cache_large_name, sin_cache_name=sin_cache_large_name @@ -1955,10 +1897,7 @@ def make_rotary_embedding_multi_cache(self, **kwargs): self.rope_attrs["create_caches"] = True # Create caches for when sequence_length <= self.original_context_length - cos_cache_small_name, sin_cache_small_name = ( - "cos_cache_small", - "sin_cache_small", - ) + cos_cache_small_name, sin_cache_small_name = "cos_cache_small", "sin_cache_small" self.rope_attrs["save_caches"] = False cos_cache_small, sin_cache_small = self.make_rotary_embedding_caches( cos_cache_name=cos_cache_small_name, sin_cache_name=sin_cache_small_name @@ -1994,10 +1933,7 @@ def make_rotary_embedding_multi_cache(self, **kwargs): gather_name = "/model/attn_mask_reformat/attn_mask_subgraph/Gather_2" greater_name = f"{basename}/Greater" - greater_inputs = [ - f"{gather_name}/output_0", - f"/model/constants/INT64/{self.original_context_length}", - ] + greater_inputs = [f"{gather_name}/output_0", f"/model/constants/INT64/{self.original_context_length}"] self.make_greater(greater_name, greater_inputs, shape=[]) # Create split If nodes and return early @@ -2033,10 +1969,7 @@ def make_rotary_embedding_multi_cache(self, **kwargs): gather_name = "/model/attn_mask_reformat/attn_mask_subgraph/Gather_2" greater_name = f"{basename}/Greater" - greater_inputs = [ - f"{gather_name}/output_0", - f"/model/constants/INT64/{self.original_context_length}", - ] + greater_inputs = [f"{gather_name}/output_0", f"/model/constants/INT64/{self.original_context_length}"] self.make_greater(greater_name, greater_inputs, shape=[]) if_name = f"{basename}/If" @@ -2162,13 +2095,7 @@ def _make_skip_layer_norm( kwargs = {"epsilon": self.layernorm_attrs["epsilon"]} kwargs.update({"axis": -1, "stash_type": 1}) - self.make_node( - "LayerNormalization", - inputs=inputs, - outputs=[output_0], - name=make_layer_norm_name, - **kwargs, - ) + self.make_node("LayerNormalization", inputs=inputs, outputs=[output_0], name=make_layer_norm_name, **kwargs) self.make_value(output_0, io_dtype, shape=shape) # This expansion contrib-op can be updated / deprecated in the future. @@ -2252,21 +2179,14 @@ def make_qk_norm(self, layer_id, attention): # Reshape (BxSxD) # Save kwargs shared by LayerNorm ops and precision types to use - layernorm_kwargs = { - "epsilon": self.layernorm_attrs["epsilon"], - "axis": -1, - "stash_type": 1, - } + layernorm_kwargs = {"epsilon": self.layernorm_attrs["epsilon"], "axis": -1, "stash_type": 1} old_io_dtype = self.io_dtype new_io_dtype = ir.DataType.FLOAT if self.layernorm_attrs["cast"]["use_fp32"] else self.io_dtype cast = old_io_dtype != new_io_dtype # Reshape Q MatMul from BxSxD to Bx(SxN)xH before LayerNorm q_reshape_1_name = f"/model/layers.{layer_id}/attn/q_norm/Reshape_1" - q_reshape_1_inputs = [ - self.attention_attrs["q_path"], - f"/model/constants/INT64/[0, -1, {self.head_size}]", - ] + q_reshape_1_inputs = [self.attention_attrs["q_path"], f"/model/constants/INT64/[0, -1, {self.head_size}]"] q_reshape_1_output = f"{q_reshape_1_name}/output_0" self.make_reshape( q_reshape_1_name, @@ -2319,10 +2239,7 @@ def make_qk_norm(self, layer_id, attention): # Reshape K MatMul from BxSxD to Bx(SxN)xH before LayerNorm k_reshape_1_name = f"/model/layers.{layer_id}/attn/k_norm/Reshape_1" - k_reshape_1_inputs = [ - self.attention_attrs["k_path"], - f"/model/constants/INT64/[0, -1, {self.head_size}]", - ] + k_reshape_1_inputs = [self.attention_attrs["k_path"], f"/model/constants/INT64/[0, -1, {self.head_size}]"] k_reshape_1_output = f"{k_reshape_1_name}/output_0" self.make_reshape( k_reshape_1_name, @@ -2478,13 +2395,7 @@ def make_repeat_kv(self, layer_id, root_input, past_kv, present_kv, **kwargs): ) concat_1_name = f"{basename}/Concat_1" concat_1_inputs = [past_kv, f"{transpose_1_name}/output_0"] - self.make_node( - "Concat", - inputs=concat_1_inputs, - outputs=[present_kv], - name=concat_1_name, - axis=2, - ) + self.make_node("Concat", inputs=concat_1_inputs, outputs=[present_kv], name=concat_1_name, axis=2) shape_1_name = f"{basename}/Shape_1" self.make_shape(shape_1_name, present_kv, shape=[4]) @@ -2565,11 +2476,7 @@ def make_repeat_kv(self, layer_id, root_input, past_kv, present_kv, **kwargs): equal_inputs = [f"{reshape_2_name}/output_0", f"{mul_2_name}/output_0"] self.make_equal(equal_name, equal_inputs, shape=[5]) where_name = f"{basename}/Where" - where_inputs = [ - f"{equal_name}/output_0", - f"{constant_shape_name}/output_0", - f"{reshape_2_name}/output_0", - ] + where_inputs = [f"{equal_name}/output_0", f"{constant_shape_name}/output_0", f"{reshape_2_name}/output_0"] self.make_where(where_name, where_inputs, dtype=ir.DataType.INT64, shape=[5]) # Make the final nodes @@ -2840,27 +2747,15 @@ def make_attention_input_proj(self, layer_id, attention, root_input, **kwargs): else: if q_bias_exists: q_add_name = f"/model/layers.{layer_id}/attn/q_proj/Add" - self.make_add_bias( - attention.q_proj.bias, - q_add_name, - root_input=self.attention_attrs["q_path"], - ) + self.make_add_bias(attention.q_proj.bias, q_add_name, root_input=self.attention_attrs["q_path"]) self.attention_attrs["q_path"] = f"{q_add_name}/output_0" if k_bias_exists: k_add_name = f"/model/layers.{layer_id}/attn/k_proj/Add" - self.make_add_bias( - attention.k_proj.bias, - k_add_name, - root_input=self.attention_attrs["k_path"], - ) + self.make_add_bias(attention.k_proj.bias, k_add_name, root_input=self.attention_attrs["k_path"]) self.attention_attrs["k_path"] = f"{k_add_name}/output_0" if v_bias_exists: v_add_name = f"/model/layers.{layer_id}/attn/v_proj/Add" - self.make_add_bias( - attention.v_proj.bias, - v_add_name, - root_input=self.attention_attrs["v_path"], - ) + self.make_add_bias(attention.v_proj.bias, v_add_name, root_input=self.attention_attrs["v_path"]) self.attention_attrs["v_path"] = f"{v_add_name}/output_0" def make_attention_qk_subgraph(self, layer_id, attention, root_input, **kwargs): @@ -3619,29 +3514,11 @@ def make_gelu(self, layer_id, root_input, activation): output = f"{gelu_name}/output_0" if activation == "Gelu": - self.make_node( - "Gelu", - inputs=[root_input], - outputs=[output], - name=gelu_name, - approximate="none", - ) + self.make_node("Gelu", inputs=[root_input], outputs=[output], name=gelu_name, approximate="none") elif activation == "FastGelu": - self.make_node( - "Gelu", - inputs=[root_input], - outputs=[output], - name=gelu_name, - approximate="tanh", - ) + self.make_node("Gelu", inputs=[root_input], outputs=[output], name=gelu_name, approximate="tanh") else: - self.make_node( - activation, - inputs=[root_input], - outputs=[output], - name=gelu_name, - domain="com.microsoft", - ) + self.make_node(activation, inputs=[root_input], outputs=[output], name=gelu_name, domain="com.microsoft") self.make_value(output, self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) @@ -3692,13 +3569,7 @@ def make_lm_head(self, lm_head): # List order matters here. It should match the order of the below if condition checks. # Add new checks to the end of the list and after the below if condition checks. - exists_checks = [ - bias_exists, - scale_exists, - mask_exists, - softcap_exists, - cast_exists, - ] + exists_checks = [bias_exists, scale_exists, mask_exists, softcap_exists, cast_exists] matmul_basename = "/lm_head/MatMul" root_input = self.layernorm_attrs["output_0"] @@ -3862,9 +3733,9 @@ def load_weights(self, input_path): model = PeftModel.from_pretrained( model, self.extra_options["adapter_path"], cache_dir=self.cache_dir, token=self.hf_token ) - + return model - + def make_model(self, input_path): # Make inputs and outputs to ONNX model self.make_inputs_and_outputs() @@ -3953,12 +3824,7 @@ def has_final_norm(self, module, orig_model): # GGUF names (all models loaded with GGUFModel.from_pretrained) gguf_final_norm = hasattr(model, "final_norm") and module == model.final_norm - hf_names = [ - hf_norm, - hf_final_layernorm, - hf_transformer_final_layernorm, - hf_language_model_norm, - ] + hf_names = [hf_norm, hf_final_layernorm, hf_transformer_final_layernorm, hf_language_model_norm] gguf_names = [gguf_final_norm] return any(hf_names + gguf_names) @@ -4101,16 +3967,10 @@ def make_input_ids_subgraph(self, basename, past_key_gather_name): # | # Unsqueeze shared_add_name = f"{basename}/Add_1" - shared_add_inputs = [ - f"{basename}/Gather_2/output_0", - f"{past_key_gather_name}/output_0", - ] + shared_add_inputs = [f"{basename}/Gather_2/output_0", f"{past_key_gather_name}/output_0"] self.make_add(shared_add_name, shared_add_inputs, dtype=ir.DataType.INT64, shape=[]) unsqueeze_3_name = f"{basename}/Unsqueeze_3" # shared unsqueeze for input_ids and past_key_values.0.key - unsqueeze_3_inputs = [ - f"{shared_add_name}/output_0", - "/model/constants/INT64/[0]", - ] + unsqueeze_3_inputs = [f"{shared_add_name}/output_0", "/model/constants/INT64/[0]"] self.make_unsqueeze(unsqueeze_3_name, unsqueeze_3_inputs, dtype=ir.DataType.INT64, shape=[1]) # Make the additional subgraph for input_ids @@ -4120,10 +3980,7 @@ def make_input_ids_subgraph(self, basename, past_key_gather_name): # Gather (idx=1) --> Concat --> ConstantOfShape Reshape --> Less --> Where --> Unsqueeze --> Unsqueeze --> Expand # \ / \ | # Unsqueeze (unsqueeze_5) Shape --> Slice --> Squeeze --> Range --> Add -------+ - unsqueeze_inputs = [ - f"{basename}/Gather_2/output_0", - "/model/constants/INT64/[0]", - ] + unsqueeze_inputs = [f"{basename}/Gather_2/output_0", "/model/constants/INT64/[0]"] unsqueeze_4_name = f"{basename}/Unsqueeze_4" self.make_unsqueeze(unsqueeze_4_name, unsqueeze_inputs, dtype=ir.DataType.INT64, shape=[1]) unsqueeze_5_name = f"{basename}/Unsqueeze_5" @@ -4162,10 +4019,7 @@ def make_input_ids_subgraph(self, basename, past_key_gather_name): squeeze_1_inputs = [f"{slice_1_name}/output_0", "/model/constants/INT64/[0]"] self.make_squeeze(squeeze_1_name, squeeze_1_inputs, dtype=ir.DataType.INT64, shape=[]) unsqueeze_7_name = f"{basename}/output_0" - unsqueeze_7_inputs = [ - f"{squeeze_1_name}/output_0", - "/model/constants/INT64/[0]", - ] + unsqueeze_7_inputs = [f"{squeeze_1_name}/output_0", "/model/constants/INT64/[0]"] self.make_unsqueeze(unsqueeze_7_name, unsqueeze_7_inputs, dtype=ir.DataType.INT64, shape=[1]) concat_3_name = f"{basename}/Concat_3" concat_3_inputs = [f"{unsqueeze_7_name}/output_0", "/model/constants/INT64/[1]"] @@ -4186,11 +4040,7 @@ def make_input_ids_subgraph(self, basename, past_key_gather_name): squeeze_2_inputs = [f"{slice_2_name}/output_0", "/model/constants/INT64/[0]"] self.make_squeeze(squeeze_2_name, squeeze_2_inputs, dtype=ir.DataType.INT64, shape=[]) range_name = f"{basename}/Range" - range_inputs = [ - "/model/constants/INT64/0", - f"{squeeze_2_name}/output_0", - "/model/constants/INT64/1", - ] + range_inputs = ["/model/constants/INT64/0", f"{squeeze_2_name}/output_0", "/model/constants/INT64/1"] self.make_range(range_name, range_inputs) add_2_name = f"{basename}/Add_2" add_inputs = [f"{range_name}/output_0", "/model/constants/INT64/1"] @@ -4214,10 +4064,7 @@ def make_input_ids_subgraph(self, basename, past_key_gather_name): unsqueeze_8_inputs = [f"{where_2_name}/output_0", "/model/constants/INT64/[0]"] self.make_unsqueeze(unsqueeze_8_name, unsqueeze_8_inputs, dtype=self.io_dtype, shape=None) unsqueeze_9_name = f"{basename}/Unsqueeze_9" - unsqueeze_9_inputs = [ - f"{unsqueeze_8_name}/output_0", - "/model/constants/INT64/[1]", - ] + unsqueeze_9_inputs = [f"{unsqueeze_8_name}/output_0", "/model/constants/INT64/[1]"] self.make_unsqueeze(unsqueeze_9_name, unsqueeze_9_inputs, dtype=self.io_dtype, shape=None) expand_name = self.make_common_mask_reformat_subgraph( @@ -4260,30 +4107,12 @@ def make_attention_mask_subgraph(self, basename, unsqueeze_for_concat): # | | # Expand --> Cast --> Sub --> Cast --> Where cast_1_name = f"{basename}/Cast_1" - self.make_cast( - cast_1_name, - f"{expand_name}/output_0", - dtype=self.io_dtype, - shape=["unk", "unk", "unk", "unk"], - ) + self.make_cast(cast_1_name, f"{expand_name}/output_0", dtype=self.io_dtype, shape=["unk", "unk", "unk", "unk"]) sub_name = f"{basename}/Sub" - sub_inputs = [ - f"/model/constants/{self.to_str_dtype(self.io_dtype)}/1", - f"{cast_1_name}/output_0", - ] - self.make_sub( - sub_name, - sub_inputs, - dtype=self.io_dtype, - shape=["unk", "unk", "unk", "unk"], - ) + sub_inputs = [f"/model/constants/{self.to_str_dtype(self.io_dtype)}/1", f"{cast_1_name}/output_0"] + self.make_sub(sub_name, sub_inputs, dtype=self.io_dtype, shape=["unk", "unk", "unk", "unk"]) cast_2_name = f"{basename}/Cast_2" - self.make_cast( - cast_2_name, - f"{sub_name}/output_0", - dtype=ir.DataType.BOOL, - shape=["unk", "unk", "unk", "unk"], - ) + self.make_cast(cast_2_name, f"{sub_name}/output_0", dtype=ir.DataType.BOOL, shape=["unk", "unk", "unk", "unk"]) where_2_name = f"{basename}/Where_2" where_2_inputs = [ f"{cast_2_name}/output_0", @@ -4334,17 +4163,9 @@ def make_common_mask_reformat_subgraph( # Expand shape_1_name = f"{basename}/Shape_1" - self.make_shape( - shape_1_name, - root_input, - shape=[3] if self.exclude_embeds and input_ids_subgraph else [2], - ) + self.make_shape(shape_1_name, root_input, shape=[3] if self.exclude_embeds and input_ids_subgraph else [2]) shape_2_name = f"{basename}/Shape_2" - self.make_shape( - shape_2_name, - root_input, - shape=[3] if self.exclude_embeds and input_ids_subgraph else [2], - ) + self.make_shape(shape_2_name, root_input, shape=[3] if self.exclude_embeds and input_ids_subgraph else [2]) gather_1_name = f"{basename}/Gather_1" gather_1_inputs = [f"{shape_1_name}/output_0", "/model/constants/INT64/0"] self.make_gather(gather_1_name, gather_1_inputs, dtype=ir.DataType.INT64, shape=[], axis=0) @@ -4388,11 +4209,7 @@ def make_common_mask_reformat_subgraph( self.make_equal(equal_name, equal_inputs, shape=[4]) where_name = f"{basename}/Where_1" - where_inputs = [ - f"{equal_name}/output_0", - f"{constant_shape_name}/output_0", - f"{concat_name}/output_0", - ] + where_inputs = [f"{equal_name}/output_0", f"{constant_shape_name}/output_0", f"{concat_name}/output_0"] self.make_where(where_name, where_inputs, dtype=ir.DataType.INT64, shape=[4]) expand_name = f"{basename}/Expand" expand_inputs = [f"{unsqueeze_for_expand}/output_0", f"{where_name}/output_0"] diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index c9116e71bb..d37e62060e 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -61,13 +61,13 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): # Qwen 2.5 VL applies RoPE manually before attention, not fused in the op self.attention_attrs["use_rope_in_attn"] = False + # We need separate Q, K, V tensors to apply MRoPE manually. + # Packed MatMul provides a single output which would require splitting. + self.attention_attrs["use_packed_matmul"] = False + if "position_ids" not in self.input_names: print("Re-adding 'position_ids' to self.input_names.") - if "attention_mask" in self.input_names: - idx = self.input_names.index("attention_mask") - self.input_names.insert(idx + 1, "position_ids") - else: - self.input_names.append("position_ids") + self.input_names.append("position_ids") self.mrope_sections = self.rope_attrs.get("mrope", {}).get("sections", []) if not self.mrope_sections: @@ -290,10 +290,10 @@ def make_mrope_flattened_caches(self, layer_id, dyn_cos, dyn_sin): # The subgraph looks like: # dyn_cos (3, B, S, H) # | - # Slice_Half + # Slice # (3, B, S, H/2) # | - # Split + # Split # (3, B, S, sections[i]) # / | \ # Gather Gather Gather @@ -303,17 +303,17 @@ def make_mrope_flattened_caches(self, layer_id, dyn_cos, dyn_sin): # \ | / # \ | / # \ | / - # Concat + # Concat # (B, S, H/2) # | # Reshape # (B*S, H/2) - + basename = f"/model/layers.{layer_id}/attn/mrope_flattened_cache" def process_cache(input_name, name_suffix): # 1. Slice to H/2: [3, B, S, H] -> [3, B, S, H/2] - slice_name = f"{basename}/{name_suffix}/Slice_Half" + slice_name = f"{basename}/{name_suffix}/half/Slice" slice_output = f"{slice_name}/output_0" self.make_slice( slice_name, @@ -606,7 +606,7 @@ def make_attention_qk_subgraph(self, layer_id, attention, root_input, **kwargs): # | | # GroupQueryAttention <--------------+ # | - + # 1. Calculate shapes for MRoPE rotation q_shape = [ "batch_size", @@ -671,7 +671,7 @@ def make_attention_qk_subgraph(self, layer_id, attention, root_input, **kwargs): def load_weights(self, input_path): # Load the Hugging Face model print("Loading Qwen2_5_VLForConditionalGeneration model...") - hf_model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + return Qwen2_5_VLForConditionalGeneration.from_pretrained( self.model_name_or_path, config=self.config, cache_dir=self.cache_dir, @@ -691,7 +691,7 @@ def make_model(self, input_path): # We only want to export the text model model = hf_model.language_model print(f"Isolated language_model ({model.__class__.__name__}) for ONNX export.") - + # Loop through model and map each module to ONNX/ORT ops self.layer_id = 0 From 5494d12e27168aa71e4bfea137e54225b58c8bdf Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 1 Dec 2025 23:56:02 +0000 Subject: [PATCH 18/18] remove make_model from qwen --- src/python/py/models/builders/base.py | 1 - src/python/py/models/builders/qwen.py | 55 --------------------------- 2 files changed, 56 deletions(-) diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index c2ccdc9640..8e92da6427 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -50,7 +50,6 @@ def parse_hf_token(hf_token): class Model: def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): - self.config = config self.context_length = config.seq_length if hasattr(config, "seq_length") else config.max_position_embeddings self.original_context_length = ( config.original_max_position_embeddings diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index d37e62060e..b46de99e30 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -673,62 +673,7 @@ def load_weights(self, input_path): print("Loading Qwen2_5_VLForConditionalGeneration model...") return Qwen2_5_VLForConditionalGeneration.from_pretrained( self.model_name_or_path, - config=self.config, cache_dir=self.cache_dir, token=self.hf_token, trust_remote_code=self.hf_remote, ) - - def make_model(self, input_path): - # Make inputs and outputs to ONNX model - self.make_inputs_and_outputs() - - # Make pre-processing nodes - self.make_preprocessing_nodes() - - hf_model = self.load_weights(input_path) - - # We only want to export the text model - model = hf_model.language_model - print(f"Isolated language_model ({model.__class__.__name__}) for ONNX export.") - - # Loop through model and map each module to ONNX/ORT ops - self.layer_id = 0 - - # The base.Model.make_model() loop expects modules from a standard causal LM, - # so we replicate its logic here but point to the correct modules in the hf_model - - # Handle Embeddings - if not self.exclude_embeds: - print("Reading embedding layer") - # The text model's embeddings are at model.embed_tokens - self.make_embedding(model.embed_tokens.weight) - else: - # When excluding embeds, the input is `inputs_embeds` - print("Skipping embedding layer, model will expect 'inputs_embeds'.") - self.layernorm_attrs["root_input"] = "inputs_embeds" - self.layernorm_attrs["skip_input"] = "inputs_embeds" - - # Handle Decoder Layers - for layer in model.layers: - if self.layer_id < self.num_layers: - print(f"Reading decoder layer {self.layer_id}") - self.make_layer(self.layer_id, layer) - self.layer_id += 1 - - # Handle Final Norm - if self.layer_id == self.num_layers and hasattr(model, "norm"): - print("Reading final norm") - self.make_layernorm( - self.layer_id, - model.norm, - skip=True, - simple=self.layernorm_attrs["simple"], - location="final_norm", - ) - - # Handle LM Head - if not self.exclude_lm_head: - # The LM head is part of the parent Qwen2_5_VLForConditionalGeneration model - print("Reading LM head") - self.make_lm_head(hf_model.lm_head)