Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions src/python/py/models/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
Phi4MMModel,
PhiModel,
Qwen3Model,
Qwen25VLTextModel,
QwenModel,
SmolLM3Model,
)
Expand Down Expand Up @@ -161,7 +162,15 @@ def set_onnx_dtype(precision: str, extra_options: dict[str, Any]) -> ir.DataType


@torch.no_grad
def create_model(model_name, input_path, output_dir, precision, execution_provider, cache_dir, **extra_options):
def create_model(
model_name,
input_path,
output_dir,
precision,
execution_provider,
cache_dir,
**extra_options,
):
if execution_provider == "NvTensorRtRtx":
execution_provider = "trt-rtx"
extra_options["use_qdq"] = True
Expand All @@ -181,7 +190,10 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid
from peft import PeftConfig

peft_config = PeftConfig.from_pretrained(
extra_options["adapter_path"], token=hf_token, trust_remote_code=hf_remote, **extra_kwargs
extra_options["adapter_path"],
token=hf_token,
trust_remote_code=hf_remote,
**extra_kwargs,
)
config.update(peft_config.__dict__)

Expand Down Expand Up @@ -292,6 +304,16 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid
onnx_model = Qwen3Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
elif config.architectures[0] == "SmolLM3ForCausalLM":
onnx_model = SmolLM3Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
elif config.architectures[0] == "Qwen2_5_VLForConditionalGeneration":
text_config = config.text_config
for key in text_config:
if not hasattr(config, key):
setattr(config, key, getattr(text_config, key))
print(
"WARNING: This is only generating the text component of the model. Setting `--extra_options exclude_embeds=true` by default."
)
extra_options["exclude_embeds"] = True
onnx_model = Qwen25VLTextModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
elif config_only:
# Create base Model class to guess model attributes
onnx_model = Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
Expand Down
3 changes: 2 additions & 1 deletion src/python/py/models/builders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
Phi4MMModel,
PhiModel,
)
from .qwen import Qwen3Model, QwenModel
from .qwen import Qwen3Model, Qwen25VLTextModel, QwenModel
from .smollm import SmolLM3Model

__all__ = [
Expand All @@ -48,6 +48,7 @@
"Phi4MMModel",
"PhiModel",
"Qwen3Model",
"Qwen25VLTextModel",
"QwenModel",
"SmolLM3Model",
]
43 changes: 32 additions & 11 deletions src/python/py/models/builders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,13 @@
"ntk_alpha": beta_slow,
"ntk_beta": beta_fast,
}
elif "mrope_section" in config.rope_scaling:
# For models that use MRoPE (e.g. Qwen 2.5 VL)
self.rope_attrs["mrope"] = {
"sections": config.rope_scaling["mrope_section"], # Sections for MRoPE
}

def make_attention_init(self):
def is_gqa_supported(self) -> bool:
valid_gqa_configurations = {
("cpu", ir.DataType.FLOAT),
("cuda", ir.DataType.FLOAT16),
Expand All @@ -483,7 +488,10 @@
("webgpu", ir.DataType.FLOAT),
("trt-rtx", ir.DataType.FLOAT16),
}
if (self.ep, self.io_dtype) in valid_gqa_configurations:
return (self.ep, self.io_dtype) in valid_gqa_configurations

def make_attention_init(self):
if self.is_gqa_supported():
# Change model settings for GroupQueryAttention
self.attention_attrs["op_type"] = "GroupQueryAttention"
print("GroupQueryAttention (GQA) is used in this model.")
Expand Down Expand Up @@ -2684,7 +2692,11 @@
# O_MatMul
# |
# O_Add
self.make_attention_input_proj(layer_id, attention, root_input, **kwargs)
self.make_attention_qk_subgraph(layer_id, attention, root_input, **kwargs)
self.make_attention_output_proj(layer_id, attention, root_input, **kwargs)

def make_attention_input_proj(self, layer_id, attention, root_input, **kwargs):
# Unpack attention weights if needed
self.make_attention_unpacked(layer_id, attention, root_input, **kwargs)

Expand Down Expand Up @@ -2748,6 +2760,7 @@
self.make_add_bias(attention.v_proj.bias, v_add_name, root_input=self.attention_attrs["v_path"])
self.attention_attrs["v_path"] = f"{v_add_name}/output_0"

def make_attention_qk_subgraph(self, layer_id, attention, root_input, **kwargs):
# Make Q/K SimplifiedLayerNorm nodes
if self.attention_attrs["q_norm"] and self.attention_attrs["k_norm"]:
self.make_qk_norm(layer_id, attention)
Expand Down Expand Up @@ -2809,11 +2822,15 @@
**kwargs,
)

def make_attention_output_proj(self, layer_id, attention, root_input, **kwargs):
attn_name = f"/model/layers.{layer_id}/attn/{self.attention_attrs['op_type']}"
attn_output = f"{attn_name}/output_0"

# Make MatMul node (output projection weight node)
o_proj = "o_proj" if hasattr(attention, "o_proj") else "dense"
o_matmul_basename = f"/model/layers.{layer_id}/attn/o_proj/MatMul"
o_weight = getattr(attention, o_proj)
o_matmul_name = self.make_matmul(o_weight, o_matmul_basename, f"{attn_name}/output_0")
o_matmul_name = self.make_matmul(o_weight, o_matmul_basename, attn_output)

# Make Add node (output projection bias node if bias exists)
o_bias_exists = getattr(attention, o_proj).bias is not None
Expand Down Expand Up @@ -3664,13 +3681,7 @@
# Norm after last decoder layer of model (last layer --> norm)
self.layernorm_attrs["last_layernorm"] = True

def make_model(self, input_path):
# Make inputs and outputs to ONNX model
self.make_inputs_and_outputs()

# Make pre-processing nodes
self.make_preprocessing_nodes()

def load_weights(self, input_path):
# Load weights of original model
if input_path.endswith(".gguf"):
# Load GGUF model
Expand Down Expand Up @@ -3707,7 +3718,6 @@
intermediate_size=self.intermediate_size,
num_layers=self.num_layers,
)

else:
# Load PyTorch model
extra_kwargs = {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {}
Expand All @@ -3726,6 +3736,17 @@
model, self.extra_options["adapter_path"], cache_dir=self.cache_dir, token=self.hf_token
)

return model

def make_model(self, input_path):
# Make inputs and outputs to ONNX model
self.make_inputs_and_outputs()

# Make pre-processing nodes
self.make_preprocessing_nodes()

model = self.load_weights(input_path)

# Loop through model and map each module to ONNX/ORT ops
self.layer_id = 0
for module in model.modules():
Expand All @@ -3745,7 +3766,7 @@
elif (
module.__class__.__name__.endswith("DecoderLayer") or module.__class__.__name__.endswith("GLMBlock")
) and self.layer_id < self.num_layers:
# Each decoder layer of model

Check warning

Code scanning / CodeQL

Warning: Unnecessary delete statement in function

Unnecessary deletion of local variable `model` in function `make_model`.
print(f"Reading decoder layer {self.layer_id}")
self.make_layer(self.layer_id, module)
self.layer_id += 1
Expand Down
Loading
Loading