diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 9cf463af91..a987b0cfd7 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -36,6 +36,7 @@ Phi4MMModel, PhiModel, Qwen3Model, + Qwen25VLTextModel, QwenModel, SmolLM3Model, ) @@ -161,7 +162,15 @@ def set_onnx_dtype(precision: str, extra_options: dict[str, Any]) -> ir.DataType @torch.no_grad -def create_model(model_name, input_path, output_dir, precision, execution_provider, cache_dir, **extra_options): +def create_model( + model_name, + input_path, + output_dir, + precision, + execution_provider, + cache_dir, + **extra_options, +): if execution_provider == "NvTensorRtRtx": execution_provider = "trt-rtx" extra_options["use_qdq"] = True @@ -181,7 +190,10 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid from peft import PeftConfig peft_config = PeftConfig.from_pretrained( - extra_options["adapter_path"], token=hf_token, trust_remote_code=hf_remote, **extra_kwargs + extra_options["adapter_path"], + token=hf_token, + trust_remote_code=hf_remote, + **extra_kwargs, ) config.update(peft_config.__dict__) @@ -292,6 +304,16 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid onnx_model = Qwen3Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options) elif config.architectures[0] == "SmolLM3ForCausalLM": onnx_model = SmolLM3Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options) + elif config.architectures[0] == "Qwen2_5_VLForConditionalGeneration": + text_config = config.text_config + for key in text_config: + if not hasattr(config, key): + setattr(config, key, getattr(text_config, key)) + print( + "WARNING: This is only generating the text component of the model. Setting `--extra_options exclude_embeds=true` by default." 
+ ) + extra_options["exclude_embeds"] = True + onnx_model = Qwen25VLTextModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options) elif config_only: # Create base Model class to guess model attributes onnx_model = Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options) diff --git a/src/python/py/models/builders/__init__.py b/src/python/py/models/builders/__init__.py index b48d1a30e2..cc41f1182e 100644 --- a/src/python/py/models/builders/__init__.py +++ b/src/python/py/models/builders/__init__.py @@ -23,7 +23,7 @@ Phi4MMModel, PhiModel, ) -from .qwen import Qwen3Model, QwenModel +from .qwen import Qwen3Model, Qwen25VLTextModel, QwenModel from .smollm import SmolLM3Model __all__ = [ @@ -48,6 +48,7 @@ "Phi4MMModel", "PhiModel", "Qwen3Model", + "Qwen25VLTextModel", "QwenModel", "SmolLM3Model", ] diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index 205cdec652..cb204d13e2 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -472,8 +472,13 @@ def make_rope_init(self, config): "ntk_alpha": beta_slow, "ntk_beta": beta_fast, } + elif "mrope_section" in config.rope_scaling: + # For models that use MRoPE (e.g. 
Qwen 2.5 VL) + self.rope_attrs["mrope"] = { + "sections": config.rope_scaling["mrope_section"], # Sections for MRoPE + } - def make_attention_init(self): + def is_gqa_supported(self) -> bool: valid_gqa_configurations = { ("cpu", ir.DataType.FLOAT), ("cuda", ir.DataType.FLOAT16), @@ -483,7 +488,10 @@ def make_attention_init(self): ("webgpu", ir.DataType.FLOAT), ("trt-rtx", ir.DataType.FLOAT16), } - if (self.ep, self.io_dtype) in valid_gqa_configurations: + return (self.ep, self.io_dtype) in valid_gqa_configurations + + def make_attention_init(self): + if self.is_gqa_supported(): # Change model settings for GroupQueryAttention self.attention_attrs["op_type"] = "GroupQueryAttention" print("GroupQueryAttention (GQA) is used in this model.") @@ -2684,7 +2692,11 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): # O_MatMul # | # O_Add + self.make_attention_input_proj(layer_id, attention, root_input, **kwargs) + self.make_attention_qk_subgraph(layer_id, attention, root_input, **kwargs) + self.make_attention_output_proj(layer_id, attention, root_input, **kwargs) + def make_attention_input_proj(self, layer_id, attention, root_input, **kwargs): # Unpack attention weights if needed self.make_attention_unpacked(layer_id, attention, root_input, **kwargs) @@ -2748,6 +2760,7 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): self.make_add_bias(attention.v_proj.bias, v_add_name, root_input=self.attention_attrs["v_path"]) self.attention_attrs["v_path"] = f"{v_add_name}/output_0" + def make_attention_qk_subgraph(self, layer_id, attention, root_input, **kwargs): # Make Q/K SimplifiedLayerNorm nodes if self.attention_attrs["q_norm"] and self.attention_attrs["k_norm"]: self.make_qk_norm(layer_id, attention) @@ -2809,11 +2822,15 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): **kwargs, ) + def make_attention_output_proj(self, layer_id, attention, root_input, **kwargs): + attn_name = 
f"/model/layers.{layer_id}/attn/{self.attention_attrs['op_type']}" + attn_output = f"{attn_name}/output_0" + # Make MatMul node (output projection weight node) o_proj = "o_proj" if hasattr(attention, "o_proj") else "dense" o_matmul_basename = f"/model/layers.{layer_id}/attn/o_proj/MatMul" o_weight = getattr(attention, o_proj) - o_matmul_name = self.make_matmul(o_weight, o_matmul_basename, f"{attn_name}/output_0") + o_matmul_name = self.make_matmul(o_weight, o_matmul_basename, attn_output) # Make Add node (output projection bias node if bias exists) o_bias_exists = getattr(attention, o_proj).bias is not None @@ -3664,13 +3681,7 @@ def make_layer(self, layer_id, layer): # Norm after last decoder layer of model (last layer --> norm) self.layernorm_attrs["last_layernorm"] = True - def make_model(self, input_path): - # Make inputs and outputs to ONNX model - self.make_inputs_and_outputs() - - # Make pre-processing nodes - self.make_preprocessing_nodes() - + def load_weights(self, input_path): # Load weights of original model if input_path.endswith(".gguf"): # Load GGUF model @@ -3707,7 +3718,6 @@ def make_model(self, input_path): intermediate_size=self.intermediate_size, num_layers=self.num_layers, ) - else: # Load PyTorch model extra_kwargs = {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {} @@ -3726,6 +3736,17 @@ def make_model(self, input_path): model, self.extra_options["adapter_path"], cache_dir=self.cache_dir, token=self.hf_token ) + return model + + def make_model(self, input_path): + # Make inputs and outputs to ONNX model + self.make_inputs_and_outputs() + + # Make pre-processing nodes + self.make_preprocessing_nodes() + + model = self.load_weights(input_path) + # Loop through model and map each module to ONNX/ORT ops self.layer_id = 0 for module in model.modules(): diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index 3c5133e30f..b46de99e30 100644 --- 
a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -3,10 +3,16 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from .mistral import MistralModel -class QwenModel(MistralModel): +import onnx_ir as ir +import torch +from transformers import Qwen2_5_VLForConditionalGeneration + +from .base import Model + + +class QwenModel(Model): def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) @@ -19,3 +25,655 @@ def make_attention_init(self): self.attention_attrs["q_norm"] = True self.attention_attrs["k_norm"] = True super().make_attention_init() + + +class Qwen25VLTextModel(Model): + def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): + super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) + + # The HF model (Qwen2RMSNorm) *always* computes LayerNorm in float32. + # By inheriting from `base.Model`, all `layernorm_attrs["cast"]` flags + # are `False`. This causes parity loss and type mismatch error. + # + # SOLUTION: Manually set all `cast` flags to `True`. This forces the + # builder to cast bf16 inputs -> fp32, compute LN, and cast fp32 + # outputs -> bf16, matching the HF model and fixing both errors. + # + print("Forcing LayerNorm computation to float32 (and enabling all casts) for Qwen2.5-VL parity.") + self.layernorm_attrs["cast"]["use_fp32"] = True + self.layernorm_attrs["cast"]["root_input"] = True + self.layernorm_attrs["cast"]["skip_input"] = True + self.layernorm_attrs["cast"]["output_0"] = True + self.layernorm_attrs["cast"]["output_3"] = True + + # Qwen2's RoPE *always* computes in float32. + # We must replicate this behavior. 
+ print("Forcing RoPE computation to float32 for Qwen2.5-VL parity.") + if "rope_cast" not in self.attention_attrs: + self.attention_attrs["rope_cast"] = {} + self.attention_attrs["rope_cast"]["use_fp32"] = True + + # Check rope type since huggingface model supports yarn but that is not recommended as mentioned in model card. Example: + # "rope_scaling": {"type": "mrope", "mrope_section": [16, 24,24]} + if config.rope_scaling and "type" in config.rope_scaling: + assert config.rope_scaling["type"] in ["mrope", "default"] + + # Qwen 2.5 VL applies RoPE manually before attention, not fused in the op + self.attention_attrs["use_rope_in_attn"] = False + + # We need separate Q, K, V tensors to apply MRoPE manually. + # Packed MatMul provides a single output which would require splitting. + self.attention_attrs["use_packed_matmul"] = False + + if "position_ids" not in self.input_names: + print("Re-adding 'position_ids' to self.input_names.") + self.input_names.append("position_ids") + + self.mrope_sections = self.rope_attrs.get("mrope", {}).get("sections", []) + if not self.mrope_sections: + raise ValueError("MRoPE sections not found in config.text_config.rope_scaling.mrope_section") + + # The HF logic is `mrope_section * 2`, not `[s * 2 for s in mrope_section]`. + # This results in [16, 24, 24, 16, 24, 24] + self.mrope_splits = self.mrope_sections * 2 + + if sum(self.mrope_splits) != self.head_size: + # The sum (128) should now correctly match self.head_size (128) + raise ValueError( + f"MRoPE splits {self.mrope_splits} sum ({sum(self.mrope_splits)}) does not match head size ({self.head_size})" + ) + + # Force GroupQueryAttention since make_attention() below only implements GQA. 
+ self.attention_attrs["op_type"] = "GroupQueryAttention" + + if not self.is_gqa_supported(): + print(f"Warning: {self.ep} does not support GQA for {self.io_dtype}, so GQA might fallback to CPU!") + + # Create and save the inv_freq tensor + self.make_inv_freq_tensor() + + def make_inv_freq_tensor(self): + """ + Calculates and saves the `inv_freq` tensor as an initializer. + This is copied from base.py:make_rotary_embedding_caches_from_scratch + """ + dim = int(self.rope_attrs["partial_rotary_factor"] * self.head_size) + inv_freq = 1.0 / ( + self.rope_attrs["rescale_factors"] + * (self.rope_attrs["theta"] ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + ) + + # The HF model expects H/2, not R/2 + if dim != self.head_size: + print( + f"Warning: partial_rotary_factor ({self.rope_attrs['partial_rotary_factor']}) is not 1. This might be unsupported." + ) + inv_freq = inv_freq[: (self.head_size // 2)] + + self.make_initializer(inv_freq, "model.inv_freq", to=ir.DataType.FLOAT) + print("Created and saved 'model.inv_freq' initializer.") + + def make_inputs_and_outputs(self): + # Qwen2.5-VL uses 3D position_ids + self.input_shapes["position_ids"] = [3, "batch_size", "sequence_length"] + + # Delegate to base.Model.make_inputs_and_outputs (this class subclasses Model directly, so no MistralModel override is involved) + super().make_inputs_and_outputs() + + def make_dynamic_rope_caches(self, layer_id, basename): + # Make nodes for the Dynamic RoPE Cache subgraph + # + # Re-implements Qwen2_5_VLRotaryEmbedding.forward using ONNX ops. + # Takes 3D position_ids and inv_freq and dynamically creates + # the cos/sin caches. 
+ # + # inv_freq (H/2) position_ids (3, B, S) + # | | + # Unsqueeze Unsqueeze + # | | + # Expand Cast + # (3, B, H/2, 1) (3, B, 1, S) + # | | + # +--------------------------+---------------------------+ + # | + # MatMul + # (3, B, H/2, S) + # | + # Transpose + # (3, B, S, H/2) + # | + # Concat + # (3, B, S, H) + # | + # +-------------+-------------+ + # | | + # Cos Sin + # | | + # Mul Mul + # (apply scaling) (apply scaling) + # + pos_ids_name = "position_ids" + inv_freq_name = "model.inv_freq" + head_dim_half = self.head_size // 2 + + # Get Batch Size from position_ids.shape[1] + shape_pos_ids_name = f"{basename}/pos_ids/Shape" + shape_pos_ids_output = f"{shape_pos_ids_name}/output_0" + self.make_shape(shape_pos_ids_name, pos_ids_name, [3]) + + gather_batch_size_name = f"{basename}/pos_ids/Gather" + gather_batch_size_output = f"{gather_batch_size_name}/output_0" + self.make_gather( + gather_batch_size_name, + [shape_pos_ids_output, "/model/constants/INT64/[1]"], + ir.DataType.INT64, + [1], + axis=0, + ) + + # Expand inv_freq: [H/2] -> [1, 1, H/2, 1] + unsqueeze_1_name = f"{basename}/inv_freq/Unsqueeze" + unsqueeze_1_output = f"{unsqueeze_1_name}/output_0" + self.make_unsqueeze( + unsqueeze_1_name, + [inv_freq_name, "/model/constants/INT64/[0, 1, 3]"], + ir.DataType.FLOAT, + [1, 1, head_dim_half, 1], + ) + + # Create target shape for Expand: [3, B, H/2, 1] + concat_expand_shape_name = f"{basename}/expand_shape/Concat" + concat_expand_shape_output = f"{concat_expand_shape_name}/output_0" + self.make_concat( + concat_expand_shape_name, + [ + "/model/constants/INT64/[3]", + gather_batch_size_output, + f"/model/constants/INT64/[{head_dim_half}, 1]", + ], + ir.DataType.INT64, + [4], + axis=0, + ) + + expand_name = f"{basename}/inv_freq/Expand" + expand_output = f"{expand_name}/output_0" + self.make_expand( + expand_name, + [unsqueeze_1_output, concat_expand_shape_output], + ir.DataType.FLOAT, + [3, "batch_size", head_dim_half, 1], + ) + + # Expand position_ids: [3, B, S] 
-> [3, B, 1, S] + unsqueeze_2_name = f"{basename}/pos_ids/Unsqueeze" + unsqueeze_2_output = f"{unsqueeze_2_name}/output_0" + self.make_unsqueeze( + unsqueeze_2_name, + [pos_ids_name, "/model/constants/INT64/[2]"], + ir.DataType.INT64, + [3, "batch_size", 1, "sequence_length"], + ) + + # Cast position_ids to float + cast_name = f"{basename}/pos_ids/Cast" + cast_output = f"{cast_name}/output_0" + self.make_cast( + cast_name, + unsqueeze_2_output, + ir.DataType.FLOAT, + [3, "batch_size", 1, "sequence_length"], + ) + + # MatMul: [3, B, H/2, 1] @ [3, B, 1, S] -> [3, B, H/2, S] + matmul_name = f"{basename}/freqs/MatMul" + matmul_output = f"{matmul_name}/output_0" + self.make_node("MatMul", [expand_output, cast_output], [matmul_output], name=matmul_name) + self.make_value( + matmul_output, + ir.DataType.FLOAT, + [3, "batch_size", head_dim_half, "sequence_length"], + ) + + # Transpose: [3, B, H/2, S] -> [3, B, S, H/2] + transpose_name = f"{basename}/freqs/Transpose" + transpose_output = f"{transpose_name}/output_0" + self.make_transpose( + transpose_name, + matmul_output, + ir.DataType.FLOAT, + [3, "batch_size", "sequence_length", head_dim_half], + perm=[0, 1, 3, 2], + ) + + # Concat (freqs, freqs): [3, B, S, H/2] -> [3, B, S, H] + concat_name = f"{basename}/Concat" + concat_output = f"{concat_name}/output_0" + self.make_concat( + concat_name, + [transpose_output, transpose_output], + ir.DataType.FLOAT, + [3, "batch_size", "sequence_length", self.head_size], + axis=-1, + ) + + # Cos(emb) and Sin(emb) + cos_name = f"{basename}/Cos" + cos_output = f"{cos_name}/output_0" + self.make_node("Cos", [concat_output], [cos_output], name=cos_name) + self.make_value( + cos_output, + ir.DataType.FLOAT, + [3, "batch_size", "sequence_length", self.head_size], + ) + + sin_name = f"{basename}/Sin" + sin_output = f"{sin_name}/output_0" + self.make_node("Sin", [concat_output], [sin_output], name=sin_name) + self.make_value( + sin_output, + ir.DataType.FLOAT, + [3, "batch_size", 
"sequence_length", self.head_size], + ) + + return cos_output, sin_output + + def make_mrope_flattened_caches(self, layer_id, dyn_cos, dyn_sin): + # Converts the 3D MRoPE caches [3, B, S, H] into flattened, interleaved caches [B*S, H/2] + # suitable for the RotaryEmbedding operator. + # The logic is: + # 1. Slice dynamic caches to H/2. + # 2. Split into 3 chunks based on mrope_sections (e.g. 16, 24, 24). + # 3. Gather Temporal(0), Height(1), Width(2) specific slices for each chunk. + # 4. Concat back to H/2. + # 5. Flatten to [B*S, H/2]. + # The subgraph looks like: + # dyn_cos (3, B, S, H) + # | + # Slice + # (3, B, S, H/2) + # | + # Split + # (3, B, S, sections[i]) + # / | \ + # Gather Gather Gather + # idx=0 idx=1 idx=2 + # / | \ + # Squeeze Squeeze Squeeze + # \ | / + # \ | / + # \ | / + # Concat + # (B, S, H/2) + # | + # Reshape + # (B*S, H/2) + + basename = f"/model/layers.{layer_id}/attn/mrope_flattened_cache" + + def process_cache(input_name, name_suffix): + # 1. Slice to H/2: [3, B, S, H] -> [3, B, S, H/2] + slice_name = f"{basename}/{name_suffix}/half/Slice" + slice_output = f"{slice_name}/output_0" + self.make_slice( + slice_name, + [ + input_name, + "/model/constants/INT64/[0]", + f"/model/constants/INT64/[{self.head_size // 2}]", + "/model/constants/INT64/[-1]", + ], + ir.DataType.FLOAT, + [3, "batch_size", "sequence_length", self.head_size // 2], + ) + + # Create a Constant node for mrope_sections: [16, 24, 24] + sections_name = f"{basename}/mrope_sections/Constant" + sections_output = f"{basename}/mrope_sections" + self.make_node( + "Constant", + [], + [sections_output], + name=sections_name, + value=ir.tensor(torch.tensor(self.mrope_sections, dtype=torch.int64), name=sections_output), + ) + self.make_value(sections_output, ir.DataType.INT64, [3]) + + # 2. 
Split: [3, B, S, H/2] -> 3 * [3, B, S, section_dim] + split_name = f"{basename}/{name_suffix}/Split" + split_outputs = [f"{split_name}/output_{i}" for i in range(3)] + self.make_node( + "Split", + [slice_output, sections_output], + split_outputs, + name=split_name, + axis=-1, + ) + + # 3. Gather + Squeeze: Reorder T, H, W + gathered_chunks = [] + for i in range(3): + # Chunk 0->T(0), Chunk 1->H(1), Chunk 2->W(2) + gather_name = f"{basename}/{name_suffix}/chunk_{i}/Gather" + gather_output = f"{gather_name}/output_0" + self.make_node( + "Gather", + [split_outputs[i], f"/model/constants/INT64/[{i}]"], + [gather_output], + name=gather_name, + axis=0, + ) + # Gather output is [1, B, S, dim] + + squeeze_name = f"{basename}/{name_suffix}/chunk_{i}/Squeeze" + squeeze_output = f"{squeeze_name}/output_0" + self.make_squeeze( + squeeze_name, + [gather_output, "/model/constants/INT64/[0]"], + ir.DataType.FLOAT, + ["batch_size", "sequence_length", self.mrope_sections[i]], + ) + gathered_chunks.append(squeeze_output) + + # 4. Concat: -> [B, S, H/2] + concat_name = f"{basename}/{name_suffix}/Concat" + concat_output = f"{concat_name}/output_0" + self.make_concat( + concat_name, + gathered_chunks, + ir.DataType.FLOAT, + ["batch_size", "sequence_length", self.head_size // 2], + axis=-1, + ) + + # 5. Flatten: -> [B*S, H/2] + reshape_name = f"{basename}/{name_suffix}_flat/Reshape" + reshape_output = f"{reshape_name}/output_0" + self.make_reshape( + reshape_name, + [concat_output, f"/model/constants/INT64/[-1, {self.head_size // 2}]"], + ir.DataType.FLOAT, + ["total_token_count", self.head_size // 2], + ) + return reshape_output + + flat_cos = process_cache(dyn_cos, "cos") + flat_sin = process_cache(dyn_sin, "sin") + + return flat_cos, flat_sin + + def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn_sin, num_heads, basename): + # Make nodes for the MRoPE rotation subgraph using RotaryEmbedding op + # + # 1. 
Flatten 3D caches [3, B, S, H] -> [B*S, H/2] (via make_mrope_flattened_caches) + # 2. Generate linear position IDs [B, S] (0 .. B*S-1) + # 3. Apply RotaryEmbedding + # + # dyn_cos (3, B, S, H) dyn_sin (3, B, S, H) + # | | + # make_mrope_flattened_caches (slice, split, gather, concat, flatten) + # | | + # flat_cos flat_sin + # (B*S, H/2) (B*S, H/2) + # | | + # +-----------+----------+ + # | + # q_or_k | position_ids + # (B, S, N*H) | (0 .. B*S-1) + # | | | + # Reshape | Reshape + # | | | + # Transpose | | + # (B, N, S, H) | (B, S) + # | | | + # +--------+--------+--------+--------+ + # | | + # RotaryEmbedding (com.microsoft) + # | + # output (B, N, S, H) + # | + # Transpose + # | + # Reshape + # (B, S, N*H) + + # 1. Prepare flattened MRoPE caches [B*S, H/2] + # This slices, splits, and re-assembles the 3D dynamic caches into the correct per-token layout. + flat_cos, flat_sin = self.make_mrope_flattened_caches(layer_id, dyn_cos, dyn_sin) + + # 2. Prepare position_ids [B, S] (values 0 to B*S - 1) + # RotaryEmbedding will use these indices to access the flattened cache. + # Get B*S from q_or_k shape. q_or_k is [B, S, N*H]. 
+ shape_node = f"{basename}/Shape" + self.make_shape(shape_node, q_or_k_path, [3]) + + # Extract B and S + batch_size_node = f"{basename}/BatchSize/Gather" + batch_size_out = f"{batch_size_node}/output_0" + self.make_gather( + batch_size_node, [f"{shape_node}/output_0", "/model/constants/INT64/[0]"], ir.DataType.INT64, [], 0 + ) + + seq_len_node = f"{basename}/SeqLen/Gather" + seq_len_out = f"{seq_len_node}/output_0" + self.make_gather( + seq_len_node, [f"{shape_node}/output_0", "/model/constants/INT64/[1]"], ir.DataType.INT64, [], 0 + ) + + # Calculate Total Tokens = B * S + mul_len_node = f"{basename}/TotalLen/Mul" + mul_len_out = f"{mul_len_node}/output_0" + self.make_node("Mul", [batch_size_out, seq_len_out], [mul_len_out], name=mul_len_node) + self.make_value(mul_len_out, ir.DataType.INT64, []) + + # Range(0, TotalTokens) + range_node = f"{basename}/Range" + range_out = f"{range_node}/output_0" + self.make_node( + "Range", ["/model/constants/INT64/0", mul_len_out, "/model/constants/INT64/1"], [range_out], name=range_node + ) + self.make_value(range_out, ir.DataType.INT64, ["total_token_count"]) + + # Slice Position IDs shape from input shape (take first 2 dims) + slice_shape_node = f"{basename}/SliceShape" + slice_shape_out = f"{slice_shape_node}/output_0" + self.make_slice( + slice_shape_node, + [ + f"{shape_node}/output_0", + "/model/constants/INT64/[0]", + "/model/constants/INT64/[2]", + "/model/constants/INT64/[0]", + ], + ir.DataType.INT64, + [2], + ) + + # Reshape Range output to [B, S] + pos_ids_reshape_node = f"{basename}/PosIds/Reshape" + pos_ids_out = f"{pos_ids_reshape_node}/output_0" + self.make_reshape( + pos_ids_reshape_node, [range_out, slice_shape_out], ir.DataType.INT64, ["batch_size", "sequence_length"] + ) + + # 3. Prepare Q/K input [B, N, S, H] + # Input is [B, S, N*H]. 
Reshape -> [B, S, N, H] -> Transpose -> [B, N, S, H] + reshape_in_node = f"{basename}/Input/Reshape" + reshape_in_out = f"{reshape_in_node}/output_0" + self.make_reshape( + reshape_in_node, + [q_or_k_path, f"/model/constants/INT64/[0, 0, {num_heads}, {self.head_size}]"], + self.io_dtype, + ["batch_size", "sequence_length", num_heads, self.head_size], + ) + + transpose_in_node = f"{basename}/Input/Transpose" + transpose_in_out = f"{transpose_in_node}/output_0" + target_shape_bnsh = ["batch_size", num_heads, "sequence_length", self.head_size] + self.make_transpose(transpose_in_node, reshape_in_out, self.io_dtype, target_shape_bnsh, [0, 2, 1, 3]) + + # 4. Handle Type Casting + # RotaryEmbedding requires input, cos, sin to be same type. + # Qwen2.5-VL forces float32 computation. + force_fp32 = self.attention_attrs.get("rope_cast", {}).get("use_fp32", False) + compute_dtype = ir.DataType.FLOAT if force_fp32 else self.io_dtype + + rope_input = transpose_in_out + if force_fp32 and self.io_dtype != ir.DataType.FLOAT: + cast_in_node = f"{basename}/Input/Cast" + rope_input = f"{cast_in_node}/output_0" + self.make_cast(cast_in_node, transpose_in_out, compute_dtype, target_shape_bnsh) + + rope_cos = flat_cos + rope_sin = flat_sin + # Note: dyn_cos is Float. flat_cos is Float. If compute_dtype is not Float (e.g. fp16), we must cast cache. + if compute_dtype != ir.DataType.FLOAT: + # Cache is Float, we need FP16 + cast_cos_node = f"{basename}/Cos/Cast" + rope_cos = f"{cast_cos_node}/output_0" + self.make_cast(cast_cos_node, flat_cos, compute_dtype, ["total_token_count", self.head_size // 2]) + + cast_sin_node = f"{basename}/Sin/Cast" + rope_sin = f"{cast_sin_node}/output_0" + self.make_cast(cast_sin_node, flat_sin, compute_dtype, ["total_token_count", self.head_size // 2]) + + # 5. 
RotaryEmbedding Node + rope_node = f"{basename}/RotaryEmbedding" + rope_output = f"{rope_node}/output_0" + self.make_node( + "RotaryEmbedding", + [rope_input, pos_ids_out, rope_cos, rope_sin], + [rope_output], + name=rope_node, + domain="com.microsoft", + rotary_embedding_dim=self.head_size, + num_heads=num_heads, + interleaved=0, # False, matches rotate_half logic + ) + self.make_value(rope_output, compute_dtype, target_shape_bnsh) + + # 6. Post-process Output + # Cast back if needed -> Transpose -> Reshape + final_rope_output = rope_output + if force_fp32 and self.io_dtype != ir.DataType.FLOAT: + cast_out_node = f"{basename}/Output/Cast" + final_rope_output = f"{cast_out_node}/output_0" + self.make_cast(cast_out_node, rope_output, self.io_dtype, target_shape_bnsh) + + transpose_out_node = f"{basename}/Output/Transpose" + transpose_out_out = f"{transpose_out_node}/output_0" + self.make_transpose( + transpose_out_node, + final_rope_output, + self.io_dtype, + ["batch_size", "sequence_length", num_heads, self.head_size], + [0, 2, 1, 3], + ) + + reshape_out_node = f"{basename}/Output/Reshape" + reshape_out_out = f"{reshape_out_node}/output_0" + self.make_reshape( + reshape_out_node, + [transpose_out_out, f"/model/constants/INT64/[0, 0, {num_heads * self.head_size}]"], + self.io_dtype, + q_or_k_shape, + ) + + return reshape_out_out + + def make_attention_qk_subgraph(self, layer_id, attention, root_input, **kwargs): + # Make nodes for the Attention subgraph (with MRoPE) + # + # q_path k_path v_path + # | | | + # | | +-----------------+ + # | | | + # (make_dynamic_rope_caches) | + # | | + # +-----+-----+ | + # | | | + # dyn_cos dyn_sin | + # | | | + # v v | + # (apply_mrope_rotation for Q) | + # | | + # Q_Rot | + # | (apply_mrope_rotation for K) | + # | | | + # | K_Rot | + # | | | + # +--------+--------+ | + # | | + # GroupQueryAttention <--------------+ + # | + + # 1. 
Calculate shapes for MRoPE rotation + q_shape = [ + "batch_size", + "sequence_length", + self.num_attn_heads * self.head_size, + ] + k_shape = [ + "batch_size", + "sequence_length", + self.num_kv_heads * self.head_size, + ] + + # 2. Apply 3D RoPE (MRoPE) + cos_dynamic, sin_dynamic = self.make_dynamic_rope_caches( + layer_id, basename=f"/model/layers.{layer_id}/attn/mrope_dynamic_cache" + ) + + # Apply rotation to Q + self.attention_attrs["q_path"] = self.apply_mrope_rotation( + layer_id, + self.attention_attrs["q_path"], + q_shape, + cos_dynamic, + sin_dynamic, + self.num_attn_heads, + basename=f"/model/layers.{layer_id}/attn/q_mrope", + ) + + # Apply rotation to K + self.attention_attrs["k_path"] = self.apply_mrope_rotation( + layer_id, + self.attention_attrs["k_path"], + k_shape, + cos_dynamic, + sin_dynamic, + self.num_kv_heads, + basename=f"/model/layers.{layer_id}/attn/k_mrope", + ) + + # 3. Call GroupQueryAttention op + past_k = f"past_key_values.{layer_id}.key" + past_v = f"past_key_values.{layer_id}.value" + present_k = f"present.{layer_id}.key" + present_v = f"present.{layer_id}.value" + + attn_name = f"/model/layers.{layer_id}/attn/{self.attention_attrs['op_type']}" + self.make_attention_op( + attn_name, + q_path=self.attention_attrs["q_path"], + k_path=self.attention_attrs["k_path"], + v_path=self.attention_attrs["v_path"], + past_k=past_k, + past_v=past_v, + present_k=present_k, + present_v=present_v, + # Pass empty strings for fused caches since we applied RoPE manually + cos_cache="", + sin_cache="", + **kwargs, + ) + + def load_weights(self, input_path): + # Load the Hugging Face model + print("Loading Qwen2_5_VLForConditionalGeneration model...") + return Qwen2_5_VLForConditionalGeneration.from_pretrained( + self.model_name_or_path, + cache_dir=self.cache_dir, + token=self.hf_token, + trust_remote_code=self.hf_remote, + ) diff --git a/test/python/models/qwen_2.5_vl/run.sh b/test/python/models/qwen_2.5_vl/run.sh new file mode 100644 index 
0000000000..6da708b743 --- /dev/null +++ b/test/python/models/qwen_2.5_vl/run.sh @@ -0,0 +1,66 @@ +#!/bin/bash +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +# This script builds and tests either an fp32, bf16 or fp16 Qwen2.5-VL-3B-Instruct model. Append -f to force export. +# Usage: ./run.sh [fp32|bf16|fp16] [-f] + +# Exit immediately if a command fails +set -e + +# 1. Validate Input +if [ "$1" != "fp32" ] && [ "$1" != "bf16" ] && [ "$1" != "fp16" ]; then + echo "Error: Invalid precision." + echo "Usage: $0 fp32|bf16|fp16 [-f]" + exit 1 +fi + +# 2. Define variables based on input +PRECISION=$1 +TEST_DIR="$(CDPATH= cd -- "$(dirname -- "$0")" && pwd)" +OUTPUT_DIR="${TEST_DIR}/qwen_${PRECISION}" +ONNX_MODEL_PATH="${OUTPUT_DIR}/model.onnx" +CACHE_DIR="${TEST_DIR}/cache" +HF_MODEL="Qwen/Qwen2.5-VL-3B-Instruct" + +# Set the --bf16 or --fp16 flag for the test script +TEST_FLAG="" +if [ "$PRECISION" == "bf16" ]; then + TEST_FLAG="--bf16" +elif [ "$PRECISION" == "fp16" ]; then + TEST_FLAG="--fp16" +fi + +# 3. Remove output directory only if it exists and -f flag is provided. +if [ "$2" == "-f" ] && [ -d "${OUTPUT_DIR}" ]; then + echo "Removing existing directory: ${OUTPUT_DIR}" + rm -rf "${OUTPUT_DIR}" +fi + +# Resolve the builder directory relative to this script (not the caller's cwd). +BUILDER_DIR="$(CDPATH= cd -- "${TEST_DIR}/../../../../src/python/py/models" && pwd)" + +# 4. Run the builder script if output directory does not exist. +if ! [ -d "${OUTPUT_DIR}" ]; then + echo "--- Building ${PRECISION} model ---" + cd "${BUILDER_DIR}" + python builder.py \ + -m "${HF_MODEL}" \ + -p "${PRECISION}" \ + -o "${OUTPUT_DIR}" \ + -e cuda \ + -c "${CACHE_DIR}" +fi + +# 5. 
Run the parity test +cd "${TEST_DIR}" +echo "--- Testing ${PRECISION} model parity ---" +python test_qwen_2.5_vl.py \ + --hf_model ${HF_MODEL} \ + --cache_dir ${CACHE_DIR} \ + --onnx_model ${ONNX_MODEL_PATH} \ + ${TEST_FLAG} + +echo "--- ${PRECISION} run complete ---" \ No newline at end of file diff --git a/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py b/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py new file mode 100644 index 0000000000..3613434766 --- /dev/null +++ b/test/python/models/qwen_2.5_vl/test_qwen_2.5_vl.py @@ -0,0 +1,415 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import argparse + +import numpy as np +import onnxruntime as ort +import torch +from onnx import TensorProto +from transformers import Qwen2_5_VLForConditionalGeneration + + +def torch_dtype_to_onnx_tensor_proto(dtype: torch.dtype) -> int: + """Maps torch.dtype to onnx.TensorProto.DataType""" + if dtype == torch.float32: + return TensorProto.FLOAT + if dtype == torch.float16: + return TensorProto.FLOAT16 + if dtype == torch.bfloat16: + return TensorProto.BFLOAT16 + if dtype == torch.int64: + return TensorProto.INT64 + if dtype == torch.int32: + return TensorProto.INT32 + if dtype == torch.bool: + return TensorProto.BOOL + raise ValueError(f"Unsupported torch dtype: {dtype}") + + +def to_numpy(tensor): + """Move tensor to CPU and convert to numpy, handling bf16.""" + if tensor.dtype == torch.bfloat16: + # NumPy doesn't support bfloat16, so cast to float32 first + return tensor.detach().cpu().to(torch.float32).numpy() + return tensor.detach().cpu().numpy() + + +def compare_outputs( + hf_logits: torch.Tensor, + ort_logits: torch.Tensor, # Changed to torch.Tensor + hf_presents: list[tuple[torch.Tensor, 
torch.Tensor]], + ort_presents: list[torch.Tensor], # Changed to list[torch.Tensor] + step_name: str, + rtol: float, + atol: float, +): + """Compares logits and KV cache outputs using numpy.""" + + print(f"--- Comparing {step_name} Logits ---") + + # We can use to_numpy safely here because we'll compare fp32 vs fp32 + # or (bf16->fp32) vs (bf16->fp32) + np.testing.assert_allclose(to_numpy(hf_logits), to_numpy(ort_logits), rtol=rtol, atol=atol) + print("Logits: PASS") + + print(f"\n--- Comparing {step_name} KV Cache ---") + # hf_presents is now a list of tuples: [(k0, v0), (k1, v1), ...] + # Flatten it to a list: [k0, v0, k1, v1, ...] + hf_presents_list = [t for layer_kv in hf_presents for t in layer_kv] + + assert len(hf_presents_list) == len(ort_presents), ( + f"HF presents count ({len(hf_presents_list)}) != ORT presents count ({len(ort_presents)})" + ) + + for i in range(len(hf_presents_list)): + hf_tensor = hf_presents_list[i] + ort_tensor = ort_presents[i] + + np.testing.assert_allclose(to_numpy(hf_tensor), to_numpy(ort_tensor), rtol=rtol, atol=atol) + print(f"KV Cache (all {len(hf_presents_list)} tensors): PASS") + print(f"\nāœ… {step_name} Parity Test Passed!\n") + + +def ort_io_binding_helper( + sess: ort.InferenceSession, + input_tensors: dict[str, torch.Tensor], + output_tensors: dict[str, torch.Tensor], + device: str, +) -> None: + """ + Binds torch tensors to an ONNX Runtime IOBinding object and runs the session. + Tensors must be on the correct device (e.g., 'cuda:0'). 
+ """ + bind = sess.io_binding() + + # Get device type and index for ORT + ort_device = device.split(":")[0] + ort_device_id = 0 + if ":" in device: + ort_device_id = int(device.split(":")[1]) + + for name, tensor in input_tensors.items(): + if not tensor.is_contiguous(): + raise RuntimeError(f"Input tensor {name} is not contiguous.") + + bind.bind_input( + name, + ort_device, + ort_device_id, + torch_dtype_to_onnx_tensor_proto(tensor.dtype), + tensor.shape, + tensor.data_ptr(), + ) + + for name, tensor in output_tensors.items(): + if not tensor.is_contiguous(): + raise RuntimeError(f"Output tensor {name} is not contiguous.") + + bind.bind_output( + name, + ort_device, + ort_device_id, + torch_dtype_to_onnx_tensor_proto(tensor.dtype), + tensor.shape, + tensor.data_ptr(), + ) + + sess.run_with_iobinding(bind) + + +def test_parity( + hf_model_name: str, + cache_dir: str, + onnx_model_path: str, + use_gpu: bool, + use_bf16: bool, + use_fp16: bool, +): + """ + Runs a two-step (prefill and decode) parity test between the Hugging Face + and ONNX models. + """ + + print(f"Loading Hugging Face model: {hf_model_name}") + print("This requires `trust_remote_code=True`.") + + if not use_gpu: + print("ERROR: This test script now requires a GPU (`--cpu` is not supported) due to IOBinding.") + return + + device = "cuda:0" # IOBinding needs the specific device ID + + if use_bf16: + torch_dtype = torch.bfloat16 + # Standard BF16 tolerances + rtol, atol = 2e-1, 1 + elif use_fp16: + torch_dtype = torch.float16 + # Standard FP16 tolerances + rtol, atol = 1e-1, 5e-1 + else: + torch_dtype = torch.float32 + # Standard FP32 tolerances + rtol, atol = 1e-1, 1e-1 + + # The builder script (base.Model) upcasts logits to float32 + # ONLY when the io_dtype is bfloat16. + # For FP16 or FP32, it keeps the original dtype. 
+ logits_dtype = torch.float32 if use_bf16 else torch_dtype + + print(f"Allocating ONNX logits output buffer with dtype: {logits_dtype}") + + hf_full_model = ( + Qwen2_5_VLForConditionalGeneration.from_pretrained( + hf_model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + cache_dir=cache_dir, + ) + .to(device) + .eval() + ) + + # The ONNX model is *only* the language_model component + hf_text_model = hf_full_model.language_model + config = hf_text_model.config + + # Get model parameters + batch_size = 1 + prefill_len = 10 + decode_len = 1 + hidden_size = config.hidden_size + num_layers = config.num_hidden_layers + num_kv_heads = config.num_key_value_heads + head_dim = config.hidden_size // config.num_attention_heads + vocab_size = config.vocab_size + + print("\n--- Model Parameters ---") + print(f"Device: {device}") + print(f"DType: {torch_dtype}") + print(f"RTOL: {rtol}, ATOL: {atol}") + print(f"Layers: {num_layers}") + print(f"Hidden Size: {hidden_size}") + print(f"KV Heads: {num_kv_heads}") + print(f"Head Dim: {head_dim}") + print("------------------------\n") + + print(f"Loading ONNX model: {onnx_model_path}") + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + sess = ort.InferenceSession(onnx_model_path, providers=providers) + + # ================================================================= + # 1. PREFILL STEP + # ================================================================= + print(f"Running Prefill Step (Sequence Length = {prefill_len})...") + + # --- Create HF/Torch Inputs --- + # Use randn (normal distribution) scaled down for better stability in FP16 + # inputs_embeds are normally centered around 0, unlike rand which is [0, 1] + inputs_embeds_prefill = ( + torch.randn((batch_size, prefill_len, hidden_size), dtype=torch_dtype, device=device) * 0.001 + ) + + # Qwen2.5-VL uses 3D position IDs (temporal, height, width). + # For text tokens, all three dimensions typically use the same sequence index. 
+ pos_ids_1d_prefill = torch.arange(prefill_len, device=device).expand(batch_size, -1) + position_ids_prefill = pos_ids_1d_prefill.unsqueeze(0).expand(3, -1, -1).contiguous() + + attention_mask_prefill = torch.ones((batch_size, prefill_len), dtype=torch.int64, device=device) + + cache_position_prefill = torch.arange(prefill_len, device=device) + + # --- Create ONNX Input Tensors (on device) --- + ort_inputs_prefill = { + "inputs_embeds": inputs_embeds_prefill, + "position_ids": position_ids_prefill, + "attention_mask": attention_mask_prefill, + } + + # Create dummy pasts with 0 sequence length + past_shape = (batch_size, num_kv_heads, 0, head_dim) + dummy_past = torch.empty(past_shape, dtype=torch_dtype, device=device) + for i in range(num_layers): + ort_inputs_prefill[f"past_key_values.{i}.key"] = dummy_past + ort_inputs_prefill[f"past_key_values.{i}.value"] = dummy_past + + # --- Create ONNX Output Tensors (on device) --- + ort_logits_prefill = torch.empty((batch_size, prefill_len, vocab_size), dtype=logits_dtype, device=device) + ort_presents_prefill = [] + ort_outputs_prefill = {"logits": ort_logits_prefill} + present_shape = (batch_size, num_kv_heads, prefill_len, head_dim) + + for i in range(num_layers): + ort_present_k = torch.empty(present_shape, dtype=torch_dtype, device=device) + ort_present_v = torch.empty(present_shape, dtype=torch_dtype, device=device) + ort_outputs_prefill[f"present.{i}.key"] = ort_present_k + ort_outputs_prefill[f"present.{i}.value"] = ort_present_v + ort_presents_prefill.extend([ort_present_k, ort_present_v]) + + # --- Run HF Model --- + with torch.no_grad(): + hf_outputs_prefill = hf_text_model( + inputs_embeds=inputs_embeds_prefill, + position_ids=position_ids_prefill, + attention_mask=attention_mask_prefill, + past_key_values=None, + cache_position=cache_position_prefill, + return_dict=True, + use_cache=True, + ) + + # --- Run ONNX Model with IOBinding --- + ort_io_binding_helper(sess, ort_inputs_prefill, ort_outputs_prefill, 
device) + + # --- Compare Prefill --- + hf_logits_prefill = hf_full_model.lm_head(hf_outputs_prefill.last_hidden_state) + hf_presents_prefill = hf_outputs_prefill.past_key_values + + compare_outputs( + hf_logits_prefill, + ort_logits_prefill, # This is the tensor we pre-allocated + hf_presents_prefill, + ort_presents_prefill, # This is the list of tensors we pre-allocated + step_name="Prefill", + rtol=rtol, + atol=atol, + ) + + # ================================================================= + # 2. DECODE STEP + # ================================================================= + print(f"Running Decode Step (Sequence Length = {decode_len})...") + + # --- Create HF/Torch Inputs --- + # Use randn (normal distribution) scaled down + inputs_embeds_decode = torch.randn((batch_size, decode_len, hidden_size), dtype=torch_dtype, device=device) * 0.001 + + # Position IDs continue from prefill length + pos_ids_1d_decode = torch.tensor([[prefill_len]], dtype=torch.int64, device=device) + position_ids_decode = pos_ids_1d_decode.unsqueeze(0).expand(3, -1, -1).contiguous() + + attention_mask_decode = torch.ones((batch_size, prefill_len + decode_len), dtype=torch.int64, device=device) + + cache_position_decode = torch.tensor([prefill_len], device=device) + + # Use the KV cache from the HF prefill run + hf_past_key_values = hf_outputs_prefill.past_key_values + + # --- Create ONNX Input Tensors (on device) --- + ort_inputs_decode = { + "inputs_embeds": inputs_embeds_decode, + "position_ids": position_ids_decode, + "attention_mask": attention_mask_decode, + } + + # Use the KV cache from the ONNX prefill run (these are already torch tensors) + for i in range(num_layers): + ort_inputs_decode[f"past_key_values.{i}.key"] = ort_presents_prefill[i * 2] + ort_inputs_decode[f"past_key_values.{i}.value"] = ort_presents_prefill[i * 2 + 1] + + # --- Create ONNX Output Tensors (on device) --- + # --- FIX: Logits from bf16 ONNX model are intentionally float32 for accuracy --- + 
ort_logits_decode = torch.empty((batch_size, decode_len, vocab_size), dtype=logits_dtype, device=device) + ort_presents_decode = [] + ort_outputs_decode = {"logits": ort_logits_decode} + present_shape_decode = ( + batch_size, + num_kv_heads, + prefill_len + decode_len, + head_dim, + ) + + for i in range(num_layers): + ort_present_k = torch.empty(present_shape_decode, dtype=torch_dtype, device=device) + ort_present_v = torch.empty(present_shape_decode, dtype=torch_dtype, device=device) + ort_outputs_decode[f"present.{i}.key"] = ort_present_k + ort_outputs_decode[f"present.{i}.value"] = ort_present_v + ort_presents_decode.extend([ort_present_k, ort_present_v]) + + # --- Run HF Model --- + with torch.no_grad(): + hf_outputs_decode = hf_text_model( + inputs_embeds=inputs_embeds_decode, + position_ids=position_ids_decode, + attention_mask=attention_mask_decode, + past_key_values=hf_past_key_values, + cache_position=cache_position_decode, + return_dict=True, + use_cache=True, + ) + + # --- Run ONNX Model with IOBinding --- + ort_io_binding_helper(sess, ort_inputs_decode, ort_outputs_decode, device) + + # --- Compare Decode --- + hf_logits_decode = hf_full_model.lm_head(hf_outputs_decode.last_hidden_state) + hf_presents_decode = hf_outputs_decode.past_key_values + + compare_outputs( + hf_logits_decode, + ort_logits_decode, + hf_presents_decode, + ort_presents_decode, + step_name="Decode", + rtol=rtol, + atol=atol, + ) + + print("=" * 30) + print("šŸŽ‰ All Parity Tests Passed! 
šŸŽ‰") + print("=" * 30) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Parity test for Qwen 2.5 VL ONNX model.") + parser.add_argument( + "--hf_model", + type=str, + default="Qwen/Qwen2.5-VL-7B-Instruct", + help="Path or name of the Hugging Face model.", + ) + parser.add_argument( + "--onnx_model", + type=str, + required=True, + help="Path to the exported ONNX model file.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default="./qwen2.5_vl_7b_instruct", + help="Path to the cache directory.", + ) + + parser.add_argument( + "--cpu", + action="store_true", + help="Force running the test on CPU (Not supported with IOBinding).", + ) + + parser.add_argument("--bf16", action="store_true", help="Use bf16 precision.") + + parser.add_argument("--fp16", action="store_true", help="Use fp16 precision.") + + args = parser.parse_args() + + if args.cpu and (args.bf16 or args.fp16): + print("Warning: Cannot run bf16/fp16 on CPU. Forcing float32.") + args.bf16 = False + args.fp16 = False + + if args.cpu: + print("Warning: CPU testing with IOBinding is not set up. Forcing GPU.") + # This script is now GPU-only + + test_parity( + hf_model_name=args.hf_model, + cache_dir=args.cache_dir, + onnx_model_path=args.onnx_model, + use_gpu=True, # Forcing GPU + use_bf16=args.bf16, + use_fp16=args.fp16, + )