diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 6f2a87bb64..251391285f 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -429,7 +429,7 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir): }, } - if self.window_size is not None and self.window_size > 0: + if self.ep == "NvTensorRtRtx" and self.window_size is not None and self.window_size > 0: genai_config["model"]["decoder"]["sliding_window"] = {"window_size": self.window_size, "slide_key_value_cache": False, "slide_inputs": False} if self.ep != "cpu": @@ -2591,31 +2591,27 @@ def make_model(self, input_path): def has_final_norm(self, module, orig_model): # Find where the language model is stored to check attributes. Some classes # store the language model in a different attribute than `model.model`. - if hasattr(orig_model, "language_model"): - # Model is multimodal - # Note: This case is checked first because the `language_model` attribute and the `base_model` attribute - # exist for both multimodal models and PEFT models. However they represent different classes and their attributes - # differ. - model = orig_model.language_model - elif hasattr(orig_model, "base_model") and hasattr(orig_model.base_model, "model"): - if hasattr(orig_model.base_model.model, "model"): - # Model is from PEFT - model = orig_model.base_model.model - else: - # Model is text-based only. - model = orig_model.base_model + if orig_model.__class__.__name__.startswith("Peft"): + # Model is from PEFT + model = orig_model.base_model.model else: model = orig_model - # Hugging Face names + # Hugging Face names (all models loaded with AutoModelForCausalLM.from_pretrained) + # + # hf_norm: for most models + # hf_final_layernorm: for Phi-2 + # hf_transformer_final_layernorm: for ChatGLM-3 + # hf_language_model_norm: for Gemma-3 multimodal (4B, 12B, 27B) hf_norm = hasattr(model, "model") and hasattr(model.model, "norm") and module == model.model.norm hf_final_layernorm = hasattr(model, "model") and hasattr(model.model, "final_layernorm") and module == model.model.final_layernorm hf_transformer_final_layernorm = hasattr(model, "transformer") and hasattr(model.transformer, "encoder") and hasattr(model.transformer.encoder, "final_layernorm") and module == model.transformer.encoder.final_layernorm + hf_language_model_norm = hasattr(model, "model") and hasattr(model.model, "language_model") and hasattr(model.model.language_model, "norm") and module == model.model.language_model.norm - # GGUF names + # GGUF names (all models loaded with GGUFModel.from_pretrained) gguf_final_norm = hasattr(model, "final_norm") and module == model.final_norm - hf_names = [hf_norm, hf_final_layernorm, hf_transformer_final_layernorm] + hf_names = [hf_norm, hf_final_layernorm, hf_transformer_final_layernorm, hf_language_model_norm] gguf_names = [gguf_final_norm] return any(hf_names + gguf_names) @@ -3135,7 +3131,6 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): self.attention_attrs["scale"] = config.query_pre_attn_scalar ** -0.5 self.is_local = lambda layer_id: layer_id % 2 == 1 - def make_layernorm(self, layer_id, layernorm, skip, simple, location): if "final_norm" in location: # Set cast for final LayerNorm since it is a special case and not covered in `make_layer` @@ -3143,7 +3138,7 @@ def make_layernorm(self, layer_id, layernorm, skip, simple, location): super().make_layernorm(layer_id, layernorm, skip, simple, location) def make_layer(self, layer_id, layer): - # Gemma2 decoder layer is typically defined as: + # Gemma-2 decoder layer is typically defined as: # input_layernorm --> attention --> post_attention_layernorm --> pre_ffn_layernorm --> MLP --> post_ffn_layernorm # Adjust LayerNorm attributes because of extra LayerNorms inserted @@ -3600,7 +3595,7 @@ def make_layer(self, layer_id, layer): class Gemma3Model(Gemma2Model): def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) - self.is_local = lambda layer_id: bool((layer_id + 1) % config.sliding_window_pattern) + self.is_local = lambda layer_id: bool((layer_id + 1) % 6) self.rope_local_theta = config.rope_local_base_freq self.make_rotary_embedding_multi_cache() @@ -3809,8 +3804,6 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid onnx_model = Gemma3Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options) onnx_model.model_type = "gemma3_text" elif config.architectures[0] == "Gemma3ForConditionalGeneration": - print("WARNING: This is only generating the text component of the model. Setting `--extra_options exclude_embeds=true` by default.") - extra_options["exclude_embeds"] = True text_config = config.text_config for key in text_config: if not hasattr(config, key): @@ -3818,6 +3811,8 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid if precision == "fp16": print("WARNING: This model loses accuracy with float16 precision. Setting `--precision bf16` by default.") onnx_dtype = ir.DataType.BFLOAT16 + print("WARNING: This is only generating the text component of the model. Setting `--extra_options exclude_embeds=true` by default.") + extra_options["exclude_embeds"] = True onnx_model = Gemma3Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options) elif config.architectures[0] == "GraniteForCausalLM": onnx_model = GraniteModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)