Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions src/python/py/models/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
Phi4MMModel,
PhiModel,
Qwen3Model,
Qwen25VLTextModel,
QwenModel,
SmolLM3Model,
)
Expand Down Expand Up @@ -160,7 +161,15 @@ def set_onnx_dtype(precision: str, extra_options: dict[str, Any]) -> ir.DataType


@torch.no_grad
def create_model(model_name, input_path, output_dir, precision, execution_provider, cache_dir, **extra_options):
def create_model(
model_name,
input_path,
output_dir,
precision,
execution_provider,
cache_dir,
**extra_options,
):
if execution_provider == "NvTensorRtRtx":
execution_provider = "trt-rtx"
extra_options["use_qdq"] = True
Expand All @@ -180,7 +189,10 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid
from peft import PeftConfig

peft_config = PeftConfig.from_pretrained(
extra_options["adapter_path"], token=hf_token, trust_remote_code=hf_remote, **extra_kwargs
extra_options["adapter_path"],
token=hf_token,
trust_remote_code=hf_remote,
**extra_kwargs,
)
config.update(peft_config.__dict__)

Expand Down Expand Up @@ -291,6 +303,16 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid
onnx_model = Qwen3Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
elif config.architectures[0] == "SmolLM3ForCausalLM":
onnx_model = SmolLM3Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
elif config.architectures[0] == "Qwen2_5_VLForConditionalGeneration":
text_config = config.text_config
for key in text_config:
if not hasattr(config, key):
setattr(config, key, getattr(text_config, key))
print(
"WARNING: This is only generating the text component of the model. Setting `--extra_options exclude_embeds=true` by default."
)
extra_options["exclude_embeds"] = True
onnx_model = Qwen25VLTextModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
elif config_only:
# Create base Model class to guess model attributes
onnx_model = Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
Expand Down
3 changes: 2 additions & 1 deletion src/python/py/models/builders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
Phi4MMModel,
PhiModel,
)
from .qwen import Qwen3Model, QwenModel
from .qwen import Qwen3Model, Qwen25VLTextModel, QwenModel
from .smollm import SmolLM3Model

__all__ = [
Expand All @@ -48,6 +48,7 @@
"Phi4MMModel",
"PhiModel",
"Qwen3Model",
"Qwen25VLTextModel",
"QwenModel",
"SmolLM3Model",
]
Loading
Loading