From 5579b0c5c17fafeaace09cc33015d71cf69ef481 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 24 Feb 2026 13:59:41 +0100 Subject: [PATCH 01/22] draft --- .../configuration_timm_backbone.py | 62 ++++---- .../timm_backbone/modeling_timm_backbone.py | 143 +----------------- .../configuration_timm_wrapper.py | 33 +++- .../timm_wrapper/modeling_timm_wrapper.py | 130 +++++++++++----- .../test_modeling_timm_backbone.py | 2 +- 5 files changed, 166 insertions(+), 204 deletions(-) diff --git a/src/transformers/models/timm_backbone/configuration_timm_backbone.py b/src/transformers/models/timm_backbone/configuration_timm_backbone.py index ba5ed4eb14bf..16d2ad69d46e 100644 --- a/src/transformers/models/timm_backbone/configuration_timm_backbone.py +++ b/src/transformers/models/timm_backbone/configuration_timm_backbone.py @@ -14,15 +14,14 @@ """Configuration for Backbone models""" -from ...backbone_utils import BackboneConfigMixin -from ...configuration_utils import PreTrainedConfig from ...utils import logging +from ..timm_wrapper.configuration_timm_wrapper import TimmWrapperConfig logger = logging.get_logger(__name__) -class TimmBackboneConfig(BackboneConfigMixin, PreTrainedConfig): +class TimmBackboneConfig(TimmWrapperConfig): r""" This is the configuration class to store the configuration for a timm backbone [`TimmBackbone`]. @@ -69,42 +68,43 @@ def __init__( out_indices=None, freeze_batch_norm_2d=False, output_stride=None, + architecture=None, + do_pooling=False, + model_args=None, **kwargs, ): - self.backbone = backbone - self.num_channels = num_channels - self.features_only = features_only self.out_indices = out_indices if out_indices is not None else [-1] - self.output_stride = output_stride - self.freeze_batch_norm_2d = freeze_batch_norm_2d + architecture = backbone if backbone is not None else architecture + if model_args is None: + model_args = { + "features_only": features_only, + "in_chans": num_channels, + "output_stride": output_stride, + } + super().__init__( + architecture=architecture, + do_pooling=do_pooling, + model_args=model_args, + freeze_batch_norm_2d=freeze_batch_norm_2d, + **kwargs, + ) + + # TODO: just override getattr/setattr instead and add a warning + @property + def backbone(self): + return self.architecture - super().__init__(**kwargs) + @property + def num_channels(self): + return self.model_args["in_chans"] @property - def out_indices(self): - return self._out_indices - - @out_indices.setter - def out_indices(self, out_indices: tuple[int, ...] | list[int]): - """ - Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices. - """ - self._out_indices = list(out_indices) if out_indices is not None else out_indices - if getattr(self, "stage_names", None) is not None: - self.set_output_features_output_indices(out_features=None, out_indices=out_indices) + def output_stride(self): + return self.model_args["output_stride"] @property - def out_features(self): - return self._out_features - - @out_features.setter - def out_features(self, out_features: list[str]): - """ - Set the out_features attribute. This will also update the out_indices attribute to match the new out_features. 
- """ - self._out_features = out_features - if getattr(self, "stage_names", None) is not None: - self.set_output_features_output_indices(out_features=out_features, out_indices=None) + def features_only(self): + return self.model_args["features_only"] __all__ = ["TimmBackboneConfig"] diff --git a/src/transformers/models/timm_backbone/modeling_timm_backbone.py b/src/transformers/models/timm_backbone/modeling_timm_backbone.py index d46796d6f4cd..4041e1879900 100644 --- a/src/transformers/models/timm_backbone/modeling_timm_backbone.py +++ b/src/transformers/models/timm_backbone/modeling_timm_backbone.py @@ -13,146 +13,17 @@ # limitations under the License. -import torch -from torch import Tensor, nn +from ...utils import logging +from ..timm_wrapper.modeling_timm_wrapper import TimmWrapperBackboneModel -from ... import initialization as init -from ...backbone_utils import BackboneMixin -from ...modeling_outputs import BackboneOutput -from ...modeling_utils import PreTrainedModel -from ...utils import is_timm_available, requires_backends -from .configuration_timm_backbone import TimmBackboneConfig +logger = logging.get_logger(__name__) -if is_timm_available(): - import timm - -class TimmBackbone(BackboneMixin, PreTrainedModel): - """ - Wrapper class for timm models to be used as backbones. This enables using the timm models interchangeably with the - other models in the library keeping the same API. - """ - - main_input_name = "pixel_values" - input_modalities = ("image",) - supports_gradient_checkpointing = False - config: TimmBackboneConfig - - def __init__(self, config, **kwargs): - requires_backends(self, "timm") - - if config.backbone is None: - raise ValueError("backbone is not set in the config. Please set it to a timm model name.") - - # We just take the final layer by default. This matches the default for the transformers models. - out_indices = config.out_indices if getattr(config, "out_indices", None) is not None else (-1,) - pretrained = kwargs.pop("pretrained", False) - in_chans = kwargs.pop("in_chans", config.num_channels) - - backbone = timm.create_model( - config.backbone, - pretrained=pretrained, - # This is currently not possible for transformer architectures. - features_only=config.features_only, - in_chans=in_chans, - out_indices=out_indices, - output_stride=config.output_stride, - **kwargs, - ) - - # Needs to be called after creating timm model, because `super()` will try to infer - # `stage_names` from model architecture - super().__init__(config, timm_backbone=backbone) - self._backbone = backbone - - # Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct2d` layers of - # provided module into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively - if getattr(config, "freeze_batch_norm_2d", False): - self.freeze_batch_norm_2d() - - # These are used to control the output of the model when called. If output_hidden_states is True, then - # return_layers is modified to include all layers. 
- self._return_layers = { - layer["module"]: str(layer["index"]) for layer in self._backbone.feature_info.get_dicts() - } - self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)} - - self.post_init() - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - requires_backends(cls, ["vision", "timm"]) - - config = kwargs.pop("config", TimmBackboneConfig()) - num_channels = kwargs.pop("num_channels", config.num_channels) - features_only = kwargs.pop("features_only", config.features_only) - out_indices = kwargs.pop("out_indices", config.out_indices) - config = TimmBackboneConfig( - backbone=pretrained_model_name_or_path, - num_channels=num_channels, - features_only=features_only, - out_indices=out_indices, - ) - return super()._from_config(config, pretrained=True, **kwargs) - - def freeze_batch_norm_2d(self): - timm.utils.model.freeze_batch_norm_2d(self._backbone) - - def unfreeze_batch_norm_2d(self): - timm.utils.model.unfreeze_batch_norm_2d(self._backbone) - - @torch.no_grad() - def _init_weights(self, module): - """We need to at least re-init the non-persistent buffers if the model was initialized on meta device (we - assume weights and persistent buffers will be part of checkpoint as we have no way to control timm inits)""" - if hasattr(module, "init_non_persistent_buffers"): - module.init_non_persistent_buffers() - elif isinstance(module, nn.BatchNorm2d): - # For non-pretrained models, always initialize buffers (handles both meta device and to_empty() cases) - running_mean = getattr(module, "running_mean", None) - if running_mean is not None: - init.zeros_(module.running_mean) - init.ones_(module.running_var) - init.zeros_(module.num_batches_tracked) - - def forward( - self, - pixel_values: torch.FloatTensor, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - **kwargs, - ) -> BackboneOutput | tuple[Tensor, ...]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - - if output_attentions: - raise ValueError("Cannot output attentions for timm backbones at the moment") - - if output_hidden_states: - # We modify the return layers to include all the stages of the backbone - self._backbone.return_layers = self._all_layers - hidden_states = self._backbone(pixel_values, **kwargs) - self._backbone.return_layers = self._return_layers - feature_maps = tuple(hidden_states[i] for i in self.out_indices) - else: - feature_maps = self._backbone(pixel_values, **kwargs) - hidden_states = None - - feature_maps = tuple(feature_maps) - hidden_states = tuple(hidden_states) if hidden_states is not None else None - - if not return_dict: - output = (feature_maps,) - if output_hidden_states: - output = output + (hidden_states,) - return output - - return BackboneOutput(feature_maps=feature_maps, hidden_states=hidden_states, attentions=None) +class TimmBackbone(TimmWrapperBackboneModel): + def __init__(self, *args, **kwargs): + logger.warning("Deprecation message") + super().__init__(*args, **kwargs) __all__ = ["TimmBackbone"] diff --git a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py index 
5afd27331d6b..59581988833b 100644 --- a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py @@ -16,6 +16,7 @@ from typing import Any +from ...backbone_utils import BackboneConfigMixin from ...configuration_utils import PreTrainedConfig from ...utils import is_timm_available, logging, requires_backends @@ -27,7 +28,7 @@ logger = logging.get_logger(__name__) -class TimmWrapperConfig(PreTrainedConfig): +class TimmWrapperConfig(PreTrainedConfig, BackboneConfigMixin): r""" This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`]. @@ -46,6 +47,8 @@ class TimmWrapperConfig(PreTrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. do_pooling (`bool`, *optional*, defaults to `True`): Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not. + freeze_batch_norm_2d (`bool`, *optional*, defaults to `False`): + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. model_args (`dict[str, Any]`, *optional*): Additional keyword arguments to pass to the `timm.create_model` function. e.g. `model_args={"depth": 3}` for `timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k` to create a model with 3 blocks. Defaults to `None`. @@ -69,12 +72,14 @@ def __init__( architecture: str = "resnet50", initializer_range: float = 0.02, do_pooling: bool = True, + freeze_batch_norm_2d=False, model_args: dict[str, Any] | None = None, **kwargs, ): self.architecture = architecture self.initializer_range = initializer_range self.do_pooling = do_pooling + self.freeze_batch_norm_2d = freeze_batch_norm_2d self.model_args = model_args # named "model_args" for BC with timm super().__init__(**kwargs) @@ -129,5 +134,31 @@ def to_dict(self) -> dict[str, Any]: output.pop("label2id", None) return output + @property + def out_indices(self): + return self._out_indices + + @out_indices.setter + def out_indices(self, out_indices: tuple[int, ...] | list[int]): + """ + Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices. + """ + self._out_indices = list(out_indices) if out_indices is not None else out_indices + if getattr(self, "stage_names", None) is not None: + self.set_output_features_output_indices(out_features=None, out_indices=out_indices) + + @property + def out_features(self): + return self._out_features + + @out_features.setter + def out_features(self, out_features: list[str]): + """ + Set the out_features attribute. This will also update the out_indices attribute to match the new out_features. + """ + self._out_features = out_features + if getattr(self, "stage_names", None) is not None: + self.set_output_features_output_indices(out_features=out_features, out_indices=None) + __all__ = ["TimmWrapperConfig"] diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py index 6fcfec1389d5..dacd28d14643 100644 --- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -18,15 +18,18 @@ from torch import Tensor, nn from ... 
import initialization as init -from ...modeling_outputs import ImageClassifierOutput, ModelOutput +from ...backbone_utils import BackboneMixin +from ...modeling_outputs import BackboneOutput, ImageClassifierOutput, ModelOutput from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, is_timm_available, requires_backends +from ...utils import auto_docstring, can_return_tuple, is_timm_available, logging, requires_backends from .configuration_timm_wrapper import TimmWrapperConfig if is_timm_available(): import timm +logger = logging.get_logger(__name__) + @dataclass @auto_docstring( @@ -91,6 +94,12 @@ class TimmWrapperPreTrainedModel(PreTrainedModel): def post_init(self): self.supports_gradient_checkpointing = self._timm_model_supports_gradient_checkpointing() + + # Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct2d` layers of + # provided module into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively + if getattr(self.config, "freeze_batch_norm_2d", False): + self.freeze_batch_norm_2d() + super().post_init() def load_state_dict(self, state_dict, *args, **kwargs): @@ -142,6 +151,12 @@ def _timm_model_supports_gradient_checkpointing(self): def _set_gradient_checkpointing(self, enable: bool = True, *args, **kwargs): self.timm_model.set_grad_checkpointing(enable) + def freeze_batch_norm_2d(self): + timm.utils.model.freeze_batch_norm_2d(self.timm_model) + + def unfreeze_batch_norm_2d(self): + timm.utils.model.unfreeze_batch_norm_2d(self.timm_model) + def get_input_embeddings(self): # TIMM backbones operate directly on images and do not expose token embeddings. return None @@ -150,6 +165,74 @@ def set_input_embeddings(self, value): raise NotImplementedError("TimmWrapper models do not own token embeddings and cannot set them.") +class TimmWrapperBackboneModel(BackboneMixin, TimmWrapperPreTrainedModel): + """ + Wrapper class for timm models to be used as backbones. This enables using the timm models interchangeably with the + other models in the library keeping the same API. + """ + + def __init__(self, config, **kwargs): + requires_backends(self, ["vision", "timm"]) + + extra_init_kwargs = config.model_args or {} + self.features_only = extra_init_kwargs.get("features_only", True) + + # We just take the final layer by default. This matches the default for the transformers models. + out_indices = config.out_indices if getattr(config, "out_indices", None) is not None else (-1,) + timm_backbone = _create_timm_model_with_error_handling(config, out_indices=out_indices, **extra_init_kwargs) + + # Needs to be called after creating timm model, because `super()` will try to infer + # `stage_names` from model architecture + super().__init__(config, timm_backbone=timm_backbone) + self.timm_model = timm_backbone + + # These are used to control the output of the model when called. If output_hidden_states is True, then + # return_layers is modified to include all layers. 
+ self._return_layers = { + layer["module"]: str(layer["index"]) for layer in self.timm_model.feature_info.get_dicts() + } + self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self.timm_model.feature_info.info)} + + self.post_init() + + @property + def _backbone(self): + logger.warning("Deprecation msg") + return self.timm_model + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + **kwargs, + ) -> BackboneOutput | tuple[Tensor, ...]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + if output_attentions: + raise ValueError("Cannot output attentions for timm backbones at the moment") + + if output_hidden_states: + # We modify the return layers to include all the stages of the backbone + self.timm_model.return_layers = self._all_layers + hidden_states = self.timm_model(pixel_values) + self.timm_model.return_layers = self._return_layers + feature_maps = tuple(hidden_states[i] for i in self.out_indices) + else: + feature_maps = self.timm_model(pixel_values) + hidden_states = None + + feature_maps = tuple(feature_maps) + hidden_states = tuple(hidden_states) if hidden_states is not None else None + + return BackboneOutput(feature_maps=feature_maps, hidden_states=hidden_states, attentions=None) + + class TimmWrapperModel(TimmWrapperPreTrainedModel): """ Wrapper class for timm models to be used in transformers. @@ -164,26 +247,17 @@ def __init__(self, config: TimmWrapperConfig): self.timm_model = _create_timm_model_with_error_handling(config, num_classes=0, **extra_init_kwargs) self.post_init() + @can_return_tuple @auto_docstring def forward( self, pixel_values: torch.FloatTensor, output_attentions: bool | None = None, output_hidden_states: bool | list[int] | None = None, - return_dict: bool | None = None, do_pooling: bool | None = None, - use_cache: bool | None = None, **kwargs, ) -> TimmWrapperModelOutput | tuple[Tensor, ...]: r""" - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. Not compatible with timm wrapped models. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. Not compatible with timm wrapped models. - do_pooling (`bool`, *optional*): - Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not. If `None` is passed, the - `do_pooling` value from the config is used. - Examples: ```python >>> import torch @@ -215,7 +289,6 @@ def forward( >>> last_hidden_state = outputs.last_hidden_state ``` """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) @@ -236,7 +309,8 @@ def forward( pixel_values = pixel_values.to(self.device) if self.features_only: - last_hidden_state = self.timm_model.forward(pixel_values, **kwargs) + # TODO: ideally features only should be used with `BackboneModel`, deprecate here! 
+ last_hidden_state = self.timm_model.forward(pixel_values) hidden_states = last_hidden_state if output_hidden_states else None pooler_output = None else: @@ -255,11 +329,6 @@ def forward( else: pooler_output = None - if not return_dict: - outputs = (last_hidden_state, pooler_output, hidden_states) - outputs = tuple(output for output in outputs if output is not None) - return outputs - return TimmWrapperModelOutput( last_hidden_state=last_hidden_state, pooler_output=pooler_output, @@ -290,6 +359,7 @@ def __init__(self, config: TimmWrapperConfig): self.num_labels = config.num_labels self.post_init() + @can_return_tuple @auto_docstring def forward( self, @@ -297,7 +367,6 @@ def forward( labels: torch.LongTensor | None = None, output_attentions: bool | None = None, output_hidden_states: bool | list[int] | None = None, - return_dict: bool | None = None, **kwargs, ) -> ImageClassifierOutput | tuple[Tensor, ...]: r""" @@ -305,14 +374,6 @@ def forward( Labels for computing the image classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. Not compatible with timm wrapped models. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. Not compatible with timm wrapped models. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - **kwargs: - Additional keyword arguments passed along to the `timm` model forward. Examples: ```python @@ -342,7 +403,6 @@ def forward( >>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5) ``` """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) @@ -375,11 +435,6 @@ def forward( if labels is not None: loss = self.loss_function(labels, logits, self.config) - if not return_dict: - outputs = (loss, logits, hidden_states) - outputs = tuple(output for output in outputs if output is not None) - return outputs - return ImageClassifierOutput( loss=loss, logits=logits, @@ -387,4 +442,9 @@ def forward( ) -__all__ = ["TimmWrapperPreTrainedModel", "TimmWrapperModel", "TimmWrapperForImageClassification"] +__all__ = [ + "TimmWrapperBackboneModel", + "TimmWrapperPreTrainedModel", + "TimmWrapperModel", + "TimmWrapperForImageClassification", +] diff --git a/tests/models/timm_backbone/test_modeling_timm_backbone.py b/tests/models/timm_backbone/test_modeling_timm_backbone.py index fa8b24a40aa7..9ab71fa94a50 100644 --- a/tests/models/timm_backbone/test_modeling_timm_backbone.py +++ b/tests/models/timm_backbone/test_modeling_timm_backbone.py @@ -91,7 +91,7 @@ def setUp(self): self.config_class = TimmBackboneConfig self.model_tester = TimmBackboneModelTester(self) self.config_tester = ConfigTester( - self, config_class=self.config_class, has_text_modality=False, common_properties=["num_channels"] + self, config_class=self.config_class, has_text_modality=False, common_properties=["architecture"] ) def test_config(self): From a5dbd5fb189a14378352affb24ee22205a58d27a Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 24 Feb 2026 15:44:16 +0100 Subject: [PATCH 02/22] fix the test --- 
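For orientation, a minimal usage sketch of the backbone wrapper drafted in PATCH 01 above, assuming `TimmWrapperBackboneModel` and the backbone-style `TimmWrapperConfig` land as written; it needs `timm` installed, and the `resnet18` architecture, the `out_indices` choice, and the tensor shape are illustrative only, not part of the patch:

import torch

from transformers import TimmWrapperConfig
from transformers.models.timm_wrapper.modeling_timm_wrapper import TimmWrapperBackboneModel

# Backbone-style config: no pooling, features_only forwarded to timm.create_model.
config = TimmWrapperConfig(
    architecture="resnet18",
    do_pooling=False,
    out_indices=[-1],                    # keep only the last stage, matching the draft's default
    model_args={"features_only": True},
)

model = TimmWrapperBackboneModel(config).eval()

pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    outputs = model(pixel_values, output_hidden_states=True)

# feature_maps follows config.out_indices; hidden_states holds every backbone stage.
print([fm.shape for fm in outputs.feature_maps])
print(len(outputs.hidden_states))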
.../configuration_timm_backbone.py | 38 +++++++++++-------- .../timm_wrapper/modeling_timm_wrapper.py | 1 + 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/timm_backbone/configuration_timm_backbone.py b/src/transformers/models/timm_backbone/configuration_timm_backbone.py index 16d2ad69d46e..f0a932086717 100644 --- a/src/transformers/models/timm_backbone/configuration_timm_backbone.py +++ b/src/transformers/models/timm_backbone/configuration_timm_backbone.py @@ -59,6 +59,12 @@ class TimmBackboneConfig(TimmWrapperConfig): """ model_type = "timm_backbone" + special_attribute_map = { + "backbone": "architecture", + "num_channels": ("model_args", "in_chans"), + "output_stride": ("model_args", "output_stride"), + "features_only": ("model_args", "features_only"), + } def __init__( self, @@ -89,22 +95,22 @@ def __init__( **kwargs, ) - # TODO: just override getattr/setattr instead and add a warning - @property - def backbone(self): - return self.architecture - - @property - def num_channels(self): - return self.model_args["in_chans"] - - @property - def output_stride(self): - return self.model_args["output_stride"] - - @property - def features_only(self): - return self.model_args["features_only"] + def __setattr__(self, key, value): + if (mapped_key := super().__getattribute__("special_attribute_map").get(key)) is not None: + first_attr = self + if isinstance(mapped_key, (tuple, list)): + first_attr = super().__getattribute__("__dict__").get(mapped_key[0]) + setattr(first_attr, mapped_key[1], value) + else: + super().__setattr__(key, value) + + def __getattribute__(self, key): + if (mapped_key := super().__getattribute__("special_attribute_map").get(key)) is not None: + first_attr = self + if isinstance(mapped_key, (tuple, list)): + first_attr = super().__getattribute__(mapped_key[0]) + return getattr(first_attr, mapped_key[1]) + return super().__getattribute__(key) __all__ = ["TimmBackboneConfig"] diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py index dacd28d14643..8e04e13d16d7 100644 --- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -255,6 +255,7 @@ def forward( output_attentions: bool | None = None, output_hidden_states: bool | list[int] | None = None, do_pooling: bool | None = None, + use_cache: bool | None = None, **kwargs, ) -> TimmWrapperModelOutput | tuple[Tensor, ...]: r""" From 5c7eb1048812c5f47cd02248e143ec62fd80a40d Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 24 Feb 2026 16:53:07 +0100 Subject: [PATCH 03/22] delete from auto-map and redirect to TimmWrapper --- src/transformers/backbone_utils.py | 6 ++-- src/transformers/models/auto/auto_factory.py | 12 ++++--- .../models/auto/configuration_auto.py | 2 +- src/transformers/models/auto/modeling_auto.py | 5 ++- .../configuration_timm_backbone.py | 32 ++++++++----------- .../configuration_timm_wrapper.py | 10 +++++- 6 files changed, 36 insertions(+), 31 deletions(-) diff --git a/src/transformers/backbone_utils.py b/src/transformers/backbone_utils.py index 30a3a140b783..32ab29f7386e 100644 --- a/src/transformers/backbone_utils.py +++ b/src/transformers/backbone_utils.py @@ -330,12 +330,12 @@ def load_backbone(config): If the config is from the parent model of the backbone model itself, then we load the pretrained backbone weights if specified. 
""" - from transformers import AutoBackbone + from transformers import AutoModel backbone_config = getattr(config, "backbone_config", None) if backbone_config is None: - backbone = AutoBackbone.from_config(config=config) + backbone = AutoModel.from_config(config=config) else: - backbone = AutoBackbone.from_config(config=backbone_config) + backbone = AutoModel.from_config(config=backbone_config) return backbone diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index af9e0e569349..edc6e67cc878 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -407,6 +407,7 @@ class _BaseAutoBackboneClass(_BaseAutoModelClass): def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): requires_backends(cls, ["vision", "timm"]) from ...models.timm_backbone import TimmBackboneConfig + from ...models.timm_wrapper import TimmWrapperConfig config = kwargs.pop("config", TimmBackboneConfig()) @@ -419,11 +420,14 @@ def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *mod num_channels = kwargs.pop("num_channels", config.num_channels) features_only = kwargs.pop("features_only", config.features_only) out_indices = kwargs.pop("out_indices", config.out_indices) - config = TimmBackboneConfig( - backbone=pretrained_model_name_or_path, - num_channels=num_channels, - features_only=features_only, + config = TimmWrapperConfig( + architecture=pretrained_model_name_or_path, + do_pooling=False, out_indices=out_indices, + model_args={ + "in_chans": num_channels, + "features_only": features_only, + }, ) # Always load a pretrained model when `from_pretrained` is called kwargs.pop("use_pretrained_backbone", None) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 18f8c632182a..ffdfa5dc7fc7 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -443,7 +443,7 @@ ("time_series_transformer", "TimeSeriesTransformerConfig"), ("timesfm", "TimesFmConfig"), ("timesformer", "TimesformerConfig"), - ("timm_backbone", "TimmBackboneConfig"), + ("timm_backbone", "TimmWrapperConfig"), # for BC ("timm_wrapper", "TimmWrapperConfig"), ("trocr", "TrOCRConfig"), ("tvp", "TvpConfig"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 952ff1da2bfa..23b52dc8f16e 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -426,7 +426,6 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("time_series_transformer", "TimeSeriesTransformerModel"), ("timesfm", "TimesFmModel"), ("timesformer", "TimesformerModel"), - ("timm_backbone", "TimmBackbone"), ("timm_wrapper", "TimmWrapperModel"), ("tvp", "TvpModel"), ("udop", "UdopModel"), @@ -776,7 +775,6 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("swinv2", "Swinv2Model"), ("table-transformer", "TableTransformerModel"), ("timesformer", "TimesformerModel"), - ("timm_backbone", "TimmBackbone"), ("timm_wrapper", "TimmWrapperModel"), ("videomae", "VideoMAEModel"), ("vit", "ViTModel"), @@ -1641,7 +1639,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("swin", "SwinBackbone"), ("swinv2", "Swinv2Backbone"), ("textnet", "TextNetBackbone"), - ("timm_backbone", "TimmBackbone"), + ("timm_backbone", "TimmWrapperBackboneModel"), # for BC + 
("timm_wrapper", "TimmWrapperBackboneModel"), ("vitdet", "VitDetBackbone"), ("vitpose_backbone", "VitPoseBackbone"), ] diff --git a/src/transformers/models/timm_backbone/configuration_timm_backbone.py b/src/transformers/models/timm_backbone/configuration_timm_backbone.py index f0a932086717..d4a9b0263cce 100644 --- a/src/transformers/models/timm_backbone/configuration_timm_backbone.py +++ b/src/transformers/models/timm_backbone/configuration_timm_backbone.py @@ -74,42 +74,36 @@ def __init__( out_indices=None, freeze_batch_norm_2d=False, output_stride=None, - architecture=None, - do_pooling=False, - model_args=None, **kwargs, ): self.out_indices = out_indices if out_indices is not None else [-1] - architecture = backbone if backbone is not None else architecture - if model_args is None: - model_args = { + kwargs["architecture"] = backbone if backbone is not None else kwargs.get("architecture") + kwargs["do_pooling"] = False # hardcode for backbone model + if kwargs.get("model_args") is None: + kwargs["model_args"] = { "features_only": features_only, "in_chans": num_channels, "output_stride": output_stride, } - super().__init__( - architecture=architecture, - do_pooling=do_pooling, - model_args=model_args, - freeze_batch_norm_2d=freeze_batch_norm_2d, - **kwargs, - ) + logger.warning("Deprecation message!") + super().__init__(freeze_batch_norm_2d=freeze_batch_norm_2d, **kwargs) def __setattr__(self, key, value): if (mapped_key := super().__getattribute__("special_attribute_map").get(key)) is not None: - first_attr = self if isinstance(mapped_key, (tuple, list)): - first_attr = super().__getattribute__("__dict__").get(mapped_key[0]) - setattr(first_attr, mapped_key[1], value) + model_args = super().__getattribute__("__dict__").get(mapped_key[0]) + model_args[mapped_key[1]] = value + else: + setattr(self, mapped_key[1], value) else: super().__setattr__(key, value) def __getattribute__(self, key): if (mapped_key := super().__getattribute__("special_attribute_map").get(key)) is not None: - first_attr = self if isinstance(mapped_key, (tuple, list)): - first_attr = super().__getattribute__(mapped_key[0]) - return getattr(first_attr, mapped_key[1]) + model_args = super().__getattribute__(mapped_key[0]) + return model_args[mapped_key[1]] + return getattr(self, mapped_key[1]) return super().__getattribute__(key) diff --git a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py index 59581988833b..b66df35c202e 100644 --- a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py @@ -76,11 +76,19 @@ def __init__( model_args: dict[str, Any] | None = None, **kwargs, ): - self.architecture = architecture + is_backbone_config = kwargs.get("backbone") is not None + self.architecture = kwargs.pop("backbone") if is_backbone_config else architecture self.initializer_range = initializer_range self.do_pooling = do_pooling self.freeze_batch_norm_2d = freeze_batch_norm_2d self.model_args = model_args # named "model_args" for BC with timm + if model_args is None and is_backbone_config: + model_args = { + "features_only": kwargs.pop("features_only", True), + "in_chans": kwargs.pop("num_channels", 3), + "output_stride": kwargs.get("output_stride"), + } + super().__init__(**kwargs) @classmethod From 4303988dd8ffe751821903d5f734d8d189db513e Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 24 Feb 2026 17:28:26 +0100 Subject: [PATCH 04/22] what if? 
--- src/transformers/backbone_utils.py | 8 +- src/transformers/models/auto/auto_factory.py | 75 ++++++++----------- src/transformers/models/auto/modeling_auto.py | 3 +- 3 files changed, 33 insertions(+), 53 deletions(-) diff --git a/src/transformers/backbone_utils.py b/src/transformers/backbone_utils.py index 32ab29f7386e..cc50aa6a6100 100644 --- a/src/transformers/backbone_utils.py +++ b/src/transformers/backbone_utils.py @@ -330,12 +330,8 @@ def load_backbone(config): If the config is from the parent model of the backbone model itself, then we load the pretrained backbone weights if specified. """ - from transformers import AutoModel + from transformers import AutoBackbone backbone_config = getattr(config, "backbone_config", None) - - if backbone_config is None: - backbone = AutoModel.from_config(config=config) - else: - backbone = AutoModel.from_config(config=backbone_config) + backbone = AutoBackbone.from_config(config=backbone_config) return backbone diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index edc6e67cc878..166c0b2190a4 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -21,8 +21,6 @@ from collections.abc import Iterator from typing import Any, TypeVar -from huggingface_hub import repo_exists - from ...configuration_utils import PreTrainedConfig from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code from ...utils import ( @@ -247,8 +245,38 @@ def _prepare_config_for_auto_class(cls, config: PreTrainedConfig) -> PreTrainedC """Additional autoclass-specific config post-loading manipulation. May be overridden in subclasses.""" return config + @classmethod + def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + requires_backends(cls, ["vision", "timm"]) + from ...models.timm_wrapper import TimmWrapperConfig + + if kwargs.get("output_loading_info", False): + raise ValueError("Cannot specify `output_loading_info=True` when loading from timm") + + # Users can't pass `config` and `kwargs`, choose only one! + config = kwargs.pop("config", None) + if config is None: + config = TimmWrapperConfig( + architecture=pretrained_model_name_or_path, + do_pooling=False, + out_indices=kwargs.pop("out_indices", (-1,)), + model_args={ + "in_chans": kwargs.pop("num_channels", 3), + "features_only": kwargs.pop("features_only", True), + }, + ) + + # Always load a pretrained model when `from_pretrained` is called + kwargs.pop("use_pretrained_backbone", None) + return super().from_config(config, pretrained=True, **kwargs) + @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], *model_args, **kwargs): + # Early exit for `timm` models, they aren't hosted on the hub usually + use_timm_backbone = kwargs.pop("use_timm_backbone", None) + if use_timm_backbone: + return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + config = kwargs.pop("config", None) trust_remote_code = kwargs.get("trust_remote_code") kwargs["_from_auto"] = True @@ -399,49 +427,6 @@ def register(cls, config_class, model_class, exist_ok=False) -> None: cls._model_mapping.register(config_class, model_class, exist_ok=exist_ok) -class _BaseAutoBackboneClass(_BaseAutoModelClass): - # Base class for auto backbone models. 
- _model_mapping = None - - @classmethod - def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - requires_backends(cls, ["vision", "timm"]) - from ...models.timm_backbone import TimmBackboneConfig - from ...models.timm_wrapper import TimmWrapperConfig - - config = kwargs.pop("config", TimmBackboneConfig()) - - if kwargs.get("out_features") is not None: - raise ValueError("Cannot specify `out_features` for timm backbones") - - if kwargs.get("output_loading_info", False): - raise ValueError("Cannot specify `output_loading_info=True` when loading from timm") - - num_channels = kwargs.pop("num_channels", config.num_channels) - features_only = kwargs.pop("features_only", config.features_only) - out_indices = kwargs.pop("out_indices", config.out_indices) - config = TimmWrapperConfig( - architecture=pretrained_model_name_or_path, - do_pooling=False, - out_indices=out_indices, - model_args={ - "in_chans": num_channels, - "features_only": features_only, - }, - ) - # Always load a pretrained model when `from_pretrained` is called - kwargs.pop("use_pretrained_backbone", None) - return super().from_config(config, pretrained=True, **kwargs) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - kwargs.pop("use_timm_backbone", None) - if not repo_exists(pretrained_model_name_or_path): - return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - - return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - - def insert_head_doc(docstring, head_doc: str = ""): if len(head_doc) > 0: return docstring.replace( diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 23b52dc8f16e..9137e28ababa 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -19,7 +19,6 @@ from ...utils import logging from .auto_factory import ( - _BaseAutoBackboneClass, _BaseAutoModelClass, _LazyAutoMapping, auto_class_update, @@ -2154,7 +2153,7 @@ class AutoModelForTextToWaveform(_BaseAutoModelClass): _model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING -class AutoBackbone(_BaseAutoBackboneClass): +class AutoBackbone(_BaseAutoModelClass): _model_mapping = MODEL_FOR_BACKBONE_MAPPING From 217070eecfa8775111222ea2482c9bd217993d8c Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 24 Feb 2026 17:55:47 +0100 Subject: [PATCH 05/22] stupid error --- .../models/timm_wrapper/configuration_timm_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py index b66df35c202e..55250fb90625 100644 --- a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py @@ -81,7 +81,6 @@ def __init__( self.initializer_range = initializer_range self.do_pooling = do_pooling self.freeze_batch_norm_2d = freeze_batch_norm_2d - self.model_args = model_args # named "model_args" for BC with timm if model_args is None and is_backbone_config: model_args = { "features_only": kwargs.pop("features_only", True), @@ -89,6 +88,7 @@ def __init__( "output_stride": kwargs.get("output_stride"), } + self.model_args = model_args # named "model_args" for BC with timm super().__init__(**kwargs) @classmethod From 55e9723e85c3fcf7efd0afc5a4b310c056cd66c7 Mon Sep 17 00:00:00 2001 
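With `_BaseAutoBackboneClass` folded into `_BaseAutoModelClass` above, the timm path is selected explicitly via `use_timm_backbone=True` rather than by probing whether the repo exists on the Hub. A rough usage sketch, assuming the draft lands as written and `timm` is available; "resnet18", the `out_indices`, and the input shape are placeholders:

import torch

from transformers import AutoBackbone, TimmWrapperConfig

# Timm path: skips the Hub entirely and builds a TimmWrapper-based backbone.
backbone = AutoBackbone.from_pretrained(
    "resnet18",                  # a timm architecture name, not a Hub repo id
    use_timm_backbone=True,
    out_indices=(1, 2, 3),
)

pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feature_maps = backbone(pixel_values).feature_maps
print([fm.shape for fm in feature_maps])

# Config path: parent models now build their backbone directly from backbone_config,
# replacing the removed load_backbone() helper.
config = TimmWrapperConfig(architecture="resnet18", do_pooling=False, model_args={"features_only": True})
backbone_from_config = AutoBackbone.from_config(config=config)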
From: raushan Date: Tue, 24 Feb 2026 18:02:13 +0100 Subject: [PATCH 06/22] modeling code --- src/transformers/models/auto/configuration_auto.py | 2 +- src/transformers/models/vitmatte/modeling_vitmatte.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index ffdfa5dc7fc7..2757b8ebdd11 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -443,7 +443,7 @@ ("time_series_transformer", "TimeSeriesTransformerConfig"), ("timesfm", "TimesFmConfig"), ("timesformer", "TimesformerConfig"), - ("timm_backbone", "TimmWrapperConfig"), # for BC + ("timm_backbone", "TimmBackboneConfig"), # for BC ("timm_wrapper", "TimmWrapperConfig"), ("trocr", "TrOCRConfig"), ("tvp", "TvpConfig"), diff --git a/src/transformers/models/vitmatte/modeling_vitmatte.py b/src/transformers/models/vitmatte/modeling_vitmatte.py index 658d90e8aa85..71d82f46ec01 100644 --- a/src/transformers/models/vitmatte/modeling_vitmatte.py +++ b/src/transformers/models/vitmatte/modeling_vitmatte.py @@ -107,7 +107,11 @@ def __init__(self, config): # to enable loading HF backbone models. in_channels = 4 if config.backbone_config is not None: - in_channels = config.backbone_config.num_channels + in_channels = ( + config.backbone_config.num_channels + if hasattr(config.backbone_config, "num_channels") + else config.backbone_config.model_args["in_chans"] + ) out_channels = config.convstream_hidden_sizes From f7a0fece46f4420e04e9af0c918d14e6d255cbd9 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 24 Feb 2026 18:05:00 +0100 Subject: [PATCH 07/22] delete the utility fn --- src/transformers/backbone_utils.py | 17 ----------------- .../modeling_conditional_detr.py | 4 ++-- .../models/d_fine/modeling_d_fine.py | 4 ++-- .../models/dab_detr/modeling_dab_detr.py | 4 ++-- .../deformable_detr/modeling_deformable_detr.py | 4 ++-- .../deformable_detr/modular_deformable_detr.py | 4 ++-- .../depth_anything/modeling_depth_anything.py | 4 ++-- src/transformers/models/detr/modeling_detr.py | 4 ++-- src/transformers/models/dpt/modeling_dpt.py | 6 +++--- .../grounding_dino/modeling_grounding_dino.py | 4 ++-- .../models/mask2former/modeling_mask2former.py | 4 ++-- .../models/maskformer/modeling_maskformer.py | 4 ++-- .../modeling_mm_grounding_dino.py | 4 ++-- .../models/omdet_turbo/modeling_omdet_turbo.py | 4 ++-- .../models/oneformer/modeling_oneformer.py | 4 ++-- .../pp_doclayout_v3/modeling_pp_doclayout_v3.py | 4 ++-- .../modeling_prompt_depth_anything.py | 4 ++-- .../models/rt_detr/modeling_rt_detr.py | 4 ++-- .../models/rt_detr/modular_rt_detr.py | 4 ++-- .../models/rt_detr_v2/modeling_rt_detr_v2.py | 4 ++-- .../modeling_table_transformer.py | 4 ++-- src/transformers/models/tvp/modeling_tvp.py | 4 ++-- .../models/upernet/modeling_upernet.py | 4 ++-- .../models/vitmatte/modeling_vitmatte.py | 4 ++-- .../models/vitpose/modeling_vitpose.py | 4 ++-- .../models/zoedepth/modeling_zoedepth.py | 4 ++-- 26 files changed, 51 insertions(+), 68 deletions(-) diff --git a/src/transformers/backbone_utils.py b/src/transformers/backbone_utils.py index cc50aa6a6100..1d3a3501eeab 100644 --- a/src/transformers/backbone_utils.py +++ b/src/transformers/backbone_utils.py @@ -318,20 +318,3 @@ def consolidate_backbone_kwargs_to_config( backbone_config = config_class.from_dict(backbone_config) return backbone_config, kwargs - - -def load_backbone(config): - """ - Loads the backbone model from 
a config object. - - If the config is from the backbone model itself, then we return a backbone model with randomly initialized - weights. - - If the config is from the parent model of the backbone model itself, then we load the pretrained backbone weights - if specified. - """ - from transformers import AutoBackbone - - backbone_config = getattr(config, "backbone_config", None) - backbone = AutoBackbone.from_config(config=backbone_config) - return backbone diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 2de83de19c12..266da4d9ca08 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -26,7 +26,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput @@ -258,7 +258,7 @@ def __init__(self, config): self.config = config - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) self.intermediate_channel_sizes = backbone.channels # replace batch norm by frozen batch norm diff --git a/src/transformers/models/d_fine/modeling_d_fine.py b/src/transformers/models/d_fine/modeling_d_fine.py index 1c758f8b1dcd..4ffe90a46bb7 100644 --- a/src/transformers/models/d_fine/modeling_d_fine.py +++ b/src/transformers/models/d_fine/modeling_d_fine.py @@ -28,7 +28,7 @@ from ... import initialization as init from ...activations import ACT2CLS -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...image_transforms import center_to_corners_format, corners_to_center_format from ...modeling_outputs import BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -1351,7 +1351,7 @@ class DFineConvEncoder(nn.Module): def __init__(self, config): super().__init__() - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) if config.freeze_backbone_batch_norms: # replace batch norm by frozen batch norm diff --git a/src/transformers/models/dab_detr/modeling_dab_detr.py b/src/transformers/models/dab_detr/modeling_dab_detr.py index 421bdbce6b89..85ce946c0d49 100644 --- a/src/transformers/models/dab_detr/modeling_dab_detr.py +++ b/src/transformers/models/dab_detr/modeling_dab_detr.py @@ -21,7 +21,7 @@ from ... 
import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput @@ -211,7 +211,7 @@ def __init__(self, config: DabDetrConfig): super().__init__() self.config = config - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) # replace batch norm by frozen batch norm with torch.no_grad(): diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 3ee685a887c1..a0f6e9402984 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -29,7 +29,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...integrations import use_kernel_forward_from_hub from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions @@ -298,7 +298,7 @@ def __init__(self, config): self.config = config - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) self.intermediate_channel_sizes = backbone.channels # replace batch norm by frozen batch norm diff --git a/src/transformers/models/deformable_detr/modular_deformable_detr.py b/src/transformers/models/deformable_detr/modular_deformable_detr.py index dfbc0783fb0a..ffd378c17d30 100644 --- a/src/transformers/models/deformable_detr/modular_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modular_deformable_detr.py @@ -21,7 +21,7 @@ from torch import Tensor from ... 
import initialization as init -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...image_transforms import center_to_corners_format from ...integrations import use_kernel_forward_from_hub from ...modeling_outputs import BaseModelOutput @@ -292,7 +292,7 @@ def __init__(self, config): self.config = config - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) self.intermediate_channel_sizes = backbone.channels # replace batch norm by frozen batch norm diff --git a/src/transformers/models/depth_anything/modeling_depth_anything.py b/src/transformers/models/depth_anything/modeling_depth_anything.py index 16e1e3c0319c..25f641e81cbb 100644 --- a/src/transformers/models/depth_anything/modeling_depth_anything.py +++ b/src/transformers/models/depth_anything/modeling_depth_anything.py @@ -16,7 +16,7 @@ import torch from torch import nn -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...modeling_outputs import DepthEstimatorOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging @@ -319,7 +319,7 @@ class DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel): def __init__(self, config): super().__init__(config) - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config=config.backbone_config) self.neck = DepthAnythingNeck(config) self.head = DepthAnythingDepthEstimationHead(config) diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 4906b3510f44..c0e73a37063b 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -22,7 +22,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -258,7 +258,7 @@ def __init__(self, config): self.config = config - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) self.intermediate_channel_sizes = backbone.channels # replace batch norm by frozen batch norm diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index ac8b255bfecd..f571604fd63b 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -28,7 +28,7 @@ from ... 
import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -102,7 +102,7 @@ def __init__(self, config: DPTConfig, feature_size: tuple[int, int] | None = Non patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config=config.backbone_config) feature_dim = self.backbone.channels[-1] if len(self.backbone.channels) != 3: raise ValueError(f"Expected backbone to have 3 output features, got {len(self.backbone.channels)}") @@ -925,7 +925,7 @@ def __init__(self, config): self.backbone = None if config.is_hybrid is False and config.backbone_config is not None: - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config=config.backbone_config) else: self.dpt = DPTModel(config, add_pooling_layer=False) diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index 26952ada4894..6bfb2903aa74 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -23,7 +23,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...file_utils import ModelOutput from ...integrations import use_kernel_forward_from_hub from ...modeling_utils import PreTrainedModel @@ -368,7 +368,7 @@ def __init__(self, config): super().__init__() self.config = config - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) # replace batch norm by frozen batch norm with torch.no_grad(): diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index 3d6a2acd0968..e06b90ff16c0 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -23,7 +23,7 @@ from ... 
import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...file_utils import ModelOutput, is_scipy_available, requires_backends from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions @@ -1399,7 +1399,7 @@ def __init__(self, config: Mask2FormerConfig): """ super().__init__() - self.encoder = load_backbone(config) + self.encoder = AutoBackbone.from_config(config=config.backbone_config) self.decoder = Mask2FormerPixelDecoder(config, feature_channels=self.encoder.channels) def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> Mask2FormerPixelLevelModuleOutput: diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index cacb86b788ed..f761d7efdb96 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -23,7 +23,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithCrossAttentions @@ -1352,7 +1352,7 @@ def __init__(self, config: MaskFormerConfig): backbone_config = MaskFormerSwinConfig.from_dict(backbone_config.to_dict()) backbone_config.out_features = ["stage1", "stage2", "stage3", "stage4"] config.backbone_config = backbone_config - self.encoder = load_backbone(config) + self.encoder = AutoBackbone.from_config(config=config.backbone_config) feature_channels = self.encoder.channels self.decoder = MaskFormerPixelDecoder( diff --git a/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py b/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py index 6acaccf70023..8f8715eb3ed7 100644 --- a/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +++ b/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py @@ -27,7 +27,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...file_utils import ModelOutput from ...integrations import use_kernel_forward_from_hub from ...modeling_utils import PreTrainedModel @@ -646,7 +646,7 @@ def __init__(self, config): super().__init__() self.config = config - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) # replace batch norm by frozen batch norm with torch.no_grad(): diff --git a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py index a30b52897051..ecfded0ffe68 100644 --- a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py +++ b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py @@ -26,7 +26,7 @@ from ... 
import initialization as init from ...activations import ACT2CLS, ACT2FN -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer @@ -279,7 +279,7 @@ class OmDetTurboVisionBackbone(nn.Module): def __init__(self, config: OmDetTurboConfig): super().__init__() self.apply_layernorm_after_vision_backbone = config.apply_layernorm_after_vision_backbone - self.vision_backbone = load_backbone(config) + self.vision_backbone = AutoBackbone.from_config(config=config.backbone_config) self.layer_norms = nn.ModuleList( [nn.LayerNorm(in_channel_dim, eps=config.layer_norm_eps) for in_channel_dim in config.encoder_in_channels] ) diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index 2b4433e52fa9..ae00809eaece 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -24,7 +24,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel @@ -1456,7 +1456,7 @@ def __init__(self, config: OneFormerConfig): The configuration used to instantiate this model. """ super().__init__() - self.encoder = load_backbone(config) + self.encoder = AutoBackbone.from_config(config=config.backbone_config) self.decoder = OneFormerPixelDecoder(config, feature_channels=self.encoder.channels) def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> OneFormerPixelLevelModuleOutput: diff --git a/src/transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py b/src/transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py index d4c275b93eed..f54dea229a71 100644 --- a/src/transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +++ b/src/transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py @@ -30,7 +30,7 @@ from ... 
import initialization as init from ...activations import ACT2CLS, ACT2FN -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...image_transforms import center_to_corners_format, corners_to_center_format from ...integrations import use_kernel_forward_from_hub from ...modeling_outputs import BaseModelOutput @@ -1322,7 +1322,7 @@ class PPDocLayoutV3ConvEncoder(nn.Module): def __init__(self, config): super().__init__() - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) if config.freeze_backbone_batch_norms: # replace batch norm by frozen batch norm diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index 5461d87c609a..c84b35e476e4 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -20,7 +20,7 @@ import torch import torch.nn as nn -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...modeling_outputs import DepthEstimatorOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring @@ -375,7 +375,7 @@ class PromptDepthAnythingForDepthEstimation(PromptDepthAnythingPreTrainedModel): def __init__(self, config): super().__init__(config) - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config=config.backbone_config) self.neck = PromptDepthAnythingNeck(config) self.head = PromptDepthAnythingDepthEstimationHead(config) diff --git a/src/transformers/models/rt_detr/modeling_rt_detr.py b/src/transformers/models/rt_detr/modeling_rt_detr.py index 182d4b2c054a..f5c799b4115f 100644 --- a/src/transformers/models/rt_detr/modeling_rt_detr.py +++ b/src/transformers/models/rt_detr/modeling_rt_detr.py @@ -28,7 +28,7 @@ from ... import initialization as init from ...activations import ACT2CLS, ACT2FN -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...image_transforms import center_to_corners_format, corners_to_center_format from ...integrations import use_kernel_forward_from_hub from ...modeling_outputs import BaseModelOutput @@ -400,7 +400,7 @@ class RTDetrConvEncoder(nn.Module): def __init__(self, config): super().__init__() - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) if config.freeze_backbone_batch_norms: # replace batch norm by frozen batch norm diff --git a/src/transformers/models/rt_detr/modular_rt_detr.py b/src/transformers/models/rt_detr/modular_rt_detr.py index f9289f9e6619..b0b771c48678 100644 --- a/src/transformers/models/rt_detr/modular_rt_detr.py +++ b/src/transformers/models/rt_detr/modular_rt_detr.py @@ -23,7 +23,7 @@ from ... 
import initialization as init from ...activations import ACT2CLS, ACT2FN -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...image_processing_utils import BatchFeature from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict, get_max_height_width from ...image_transforms import center_to_corners_format, corners_to_center_format @@ -689,7 +689,7 @@ class RTDetrConvEncoder(nn.Module): def __init__(self, config): super().__init__() - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) if config.freeze_backbone_batch_norms: # replace batch norm by frozen batch norm diff --git a/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py b/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py index b5244ffda7f8..3921dd0791d1 100644 --- a/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +++ b/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py @@ -29,7 +29,7 @@ from ... import initialization as init from ...activations import ACT2CLS, ACT2FN -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...image_transforms import center_to_corners_format, corners_to_center_format from ...modeling_outputs import BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -796,7 +796,7 @@ class RTDetrV2ConvEncoder(nn.Module): def __init__(self, config): super().__init__() - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) if config.freeze_backbone_batch_norms: # replace batch norm by frozen batch norm diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index 68c66ce8248b..b39c296d9645 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -21,7 +21,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput @@ -204,7 +204,7 @@ def __init__(self, config): self.config = config - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) self.intermediate_channel_sizes = backbone.channels # replace batch norm by frozen batch norm diff --git a/src/transformers/models/tvp/modeling_tvp.py b/src/transformers/models/tvp/modeling_tvp.py index 917556c31fff..04aec13d06ed 100644 --- a/src/transformers/models/tvp/modeling_tvp.py +++ b/src/transformers/models/tvp/modeling_tvp.py @@ -21,7 +21,7 @@ from ... 
import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import PreTrainedModel @@ -135,7 +135,7 @@ def forward(self, logits, labels): class TvpVisionModel(nn.Module): def __init__(self, config): super().__init__() - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config=config.backbone_config) if config.backbone_config is not None: in_channels = config.backbone_config.hidden_sizes[-1] diff --git a/src/transformers/models/upernet/modeling_upernet.py b/src/transformers/models/upernet/modeling_upernet.py index bf497134646f..4217c2c5b832 100644 --- a/src/transformers/models/upernet/modeling_upernet.py +++ b/src/transformers/models/upernet/modeling_upernet.py @@ -17,7 +17,7 @@ from torch import nn from torch.nn import CrossEntropyLoss -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...modeling_outputs import SemanticSegmenterOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring @@ -279,7 +279,7 @@ class UperNetForSemanticSegmentation(UperNetPreTrainedModel): def __init__(self, config): super().__init__(config) - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config=config.backbone_config) # Semantic segmentation head(s) self.decode_head = UperNetHead(config, in_channels=self.backbone.channels) diff --git a/src/transformers/models/vitmatte/modeling_vitmatte.py b/src/transformers/models/vitmatte/modeling_vitmatte.py index 71d82f46ec01..b4497c459993 100644 --- a/src/transformers/models/vitmatte/modeling_vitmatte.py +++ b/src/transformers/models/vitmatte/modeling_vitmatte.py @@ -19,7 +19,7 @@ from torch import nn from ... import initialization as init -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring from .configuration_vitmatte import VitMatteConfig @@ -226,7 +226,7 @@ def __init__(self, config): super().__init__(config) self.config = config - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config=config.backbone_config) self.decoder = VitMatteDetailCaptureModule(config) # Initialize weights and apply final processing diff --git a/src/transformers/models/vitpose/modeling_vitpose.py b/src/transformers/models/vitpose/modeling_vitpose.py index 3e47f66bd03e..66dcfb8d9f2a 100644 --- a/src/transformers/models/vitpose/modeling_vitpose.py +++ b/src/transformers/models/vitpose/modeling_vitpose.py @@ -19,7 +19,7 @@ from torch import nn from ... 
import initialization as init -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack @@ -191,7 +191,7 @@ class VitPoseForPoseEstimation(VitPosePreTrainedModel): def __init__(self, config: VitPoseConfig): super().__init__(config) - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config=config.backbone_config) # add backbone attributes if not hasattr(self.backbone.config, "hidden_size"): diff --git a/src/transformers/models/zoedepth/modeling_zoedepth.py b/src/transformers/models/zoedepth/modeling_zoedepth.py index d385ca4080c2..8cbb747d5209 100644 --- a/src/transformers/models/zoedepth/modeling_zoedepth.py +++ b/src/transformers/models/zoedepth/modeling_zoedepth.py @@ -21,7 +21,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone +from ...auto import AutoBackbone from ...modeling_outputs import DepthEstimatorOutput from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging @@ -1226,7 +1226,7 @@ class ZoeDepthForDepthEstimation(ZoeDepthPreTrainedModel): def __init__(self, config): super().__init__(config) - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config=config.backbone_config) if hasattr(self.backbone.config, "hidden_size") and hasattr(self.backbone.config, "patch_size"): config.backbone_hidden_size = self.backbone.config.hidden_size From abf4993b5ba33adf58a146a57a6163e21086ac91 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 24 Feb 2026 18:09:18 +0100 Subject: [PATCH 08/22] oh no, not the corect path --- .../models/conditional_detr/modeling_conditional_detr.py | 2 +- src/transformers/models/d_fine/modeling_d_fine.py | 2 +- src/transformers/models/dab_detr/modeling_dab_detr.py | 2 +- .../models/deformable_detr/modeling_deformable_detr.py | 2 +- .../models/deformable_detr/modular_deformable_detr.py | 2 +- .../models/depth_anything/modeling_depth_anything.py | 2 +- src/transformers/models/detr/modeling_detr.py | 2 +- src/transformers/models/dpt/modeling_dpt.py | 2 +- .../models/grounding_dino/modeling_grounding_dino.py | 3 +-- src/transformers/models/mask2former/modeling_mask2former.py | 2 +- src/transformers/models/maskformer/modeling_maskformer.py | 2 +- .../models/mm_grounding_dino/modeling_mm_grounding_dino.py | 2 +- src/transformers/models/omdet_turbo/modeling_omdet_turbo.py | 3 +-- src/transformers/models/oneformer/modeling_oneformer.py | 2 +- .../models/pp_doclayout_v3/modeling_pp_doclayout_v3.py | 2 +- .../prompt_depth_anything/modeling_prompt_depth_anything.py | 2 +- src/transformers/models/rt_detr/modeling_rt_detr.py | 2 +- src/transformers/models/rt_detr/modular_rt_detr.py | 2 +- src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py | 2 +- .../models/table_transformer/modeling_table_transformer.py | 2 +- src/transformers/models/tvp/modeling_tvp.py | 2 +- src/transformers/models/upernet/modeling_upernet.py | 2 +- src/transformers/models/vitmatte/modeling_vitmatte.py | 2 +- src/transformers/models/vitpose/modeling_vitpose.py | 2 +- src/transformers/models/zoedepth/modeling_zoedepth.py | 2 +- 25 files changed, 25 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 266da4d9ca08..3d3d81450fb9 100644 --- 
a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -26,7 +26,6 @@ from ... import initialization as init from ...activations import ACT2FN -from ...auto import AutoBackbone from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput @@ -36,6 +35,7 @@ from ...utils import ModelOutput, TransformersKwargs, auto_docstring from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs +from ..auto import AutoBackbone from .configuration_conditional_detr import ConditionalDetrConfig diff --git a/src/transformers/models/d_fine/modeling_d_fine.py b/src/transformers/models/d_fine/modeling_d_fine.py index 4ffe90a46bb7..4614cab6f237 100644 --- a/src/transformers/models/d_fine/modeling_d_fine.py +++ b/src/transformers/models/d_fine/modeling_d_fine.py @@ -28,7 +28,6 @@ from ... import initialization as init from ...activations import ACT2CLS -from ...auto import AutoBackbone from ...image_transforms import center_to_corners_format, corners_to_center_format from ...modeling_outputs import BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -37,6 +36,7 @@ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check, torch_int from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ..auto import AutoBackbone from .configuration_d_fine import DFineConfig diff --git a/src/transformers/models/dab_detr/modeling_dab_detr.py b/src/transformers/models/dab_detr/modeling_dab_detr.py index 85ce946c0d49..a721563882db 100644 --- a/src/transformers/models/dab_detr/modeling_dab_detr.py +++ b/src/transformers/models/dab_detr/modeling_dab_detr.py @@ -21,7 +21,6 @@ from ... import initialization as init from ...activations import ACT2FN -from ...auto import AutoBackbone from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput @@ -31,6 +30,7 @@ auto_docstring, logging, ) +from ..auto import AutoBackbone from .configuration_dab_detr import DabDetrConfig diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index a0f6e9402984..82bb58b200ae 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -29,7 +29,6 @@ from ... 
import initialization as init from ...activations import ACT2FN -from ...auto import AutoBackbone from ...integrations import use_kernel_forward_from_hub from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions @@ -39,6 +38,7 @@ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs +from ..auto import AutoBackbone from .configuration_deformable_detr import DeformableDetrConfig diff --git a/src/transformers/models/deformable_detr/modular_deformable_detr.py b/src/transformers/models/deformable_detr/modular_deformable_detr.py index ffd378c17d30..886302bca645 100644 --- a/src/transformers/models/deformable_detr/modular_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modular_deformable_detr.py @@ -21,7 +21,6 @@ from torch import Tensor from ... import initialization as init -from ...auto import AutoBackbone from ...image_transforms import center_to_corners_format from ...integrations import use_kernel_forward_from_hub from ...modeling_outputs import BaseModelOutput @@ -37,6 +36,7 @@ ) from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs +from ..auto import AutoBackbone from ..detr.image_processing_detr_fast import DetrImageProcessorFast from ..detr.modeling_detr import ( DetrConvEncoder, diff --git a/src/transformers/models/depth_anything/modeling_depth_anything.py b/src/transformers/models/depth_anything/modeling_depth_anything.py index 25f641e81cbb..7a53e57b95b6 100644 --- a/src/transformers/models/depth_anything/modeling_depth_anything.py +++ b/src/transformers/models/depth_anything/modeling_depth_anything.py @@ -16,10 +16,10 @@ import torch from torch import nn -from ...auto import AutoBackbone from ...modeling_outputs import DepthEstimatorOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging +from ..auto import AutoBackbone from .configuration_depth_anything import DepthAnythingConfig diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index c0e73a37063b..ed50792d9e7b 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -22,7 +22,6 @@ from ... import initialization as init from ...activations import ACT2FN -from ...auto import AutoBackbone from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -41,6 +40,7 @@ ) from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ..auto import AutoBackbone from .configuration_detr import DetrConfig diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index f571604fd63b..e9f96d148d76 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -28,7 +28,6 @@ from ... 
import initialization as init from ...activations import ACT2FN -from ...auto import AutoBackbone from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -36,6 +35,7 @@ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging, torch_int from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ..auto import AutoBackbone from .configuration_dpt import DPTConfig diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index 6bfb2903aa74..2f53024db563 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -23,12 +23,11 @@ from ... import initialization as init from ...activations import ACT2FN -from ...auto import AutoBackbone from ...file_utils import ModelOutput from ...integrations import use_kernel_forward_from_hub from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging, torch_compilable_check -from ..auto import AutoModel +from ..auto import AutoBackbone, AutoModel from .configuration_grounding_dino import GroundingDinoConfig diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index e06b90ff16c0..8e4f19d05cc9 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -23,13 +23,13 @@ from ... import initialization as init from ...activations import ACT2FN -from ...auto import AutoBackbone from ...file_utils import ModelOutput, is_scipy_available, requires_backends from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...pytorch_utils import compile_compatible_method_lru_cache from ...utils import auto_docstring, is_accelerate_available, logging, torch_compilable_check +from ..auto import AutoBackbone from .configuration_mask2former import Mask2FormerConfig diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index f761d7efdb96..1ad4e0417e98 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -23,7 +23,6 @@ from ... 
import initialization as init from ...activations import ACT2FN -from ...auto import AutoBackbone from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithCrossAttentions @@ -39,6 +38,7 @@ logging, requires_backends, ) +from ..auto import AutoBackbone from ..detr import DetrConfig from .configuration_maskformer import MaskFormerConfig from .configuration_maskformer_swin import MaskFormerSwinConfig diff --git a/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py b/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py index 8f8715eb3ed7..18e8404f0231 100644 --- a/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +++ b/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py @@ -27,11 +27,11 @@ from ... import initialization as init from ...activations import ACT2FN -from ...auto import AutoBackbone from ...file_utils import ModelOutput from ...integrations import use_kernel_forward_from_hub from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, torch_compilable_check +from ..auto import AutoBackbone from ..auto.modeling_auto import AutoModel from .configuration_mm_grounding_dino import MMGroundingDinoConfig diff --git a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py index ecfded0ffe68..fb172a009642 100644 --- a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py +++ b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py @@ -26,7 +26,6 @@ from ... import initialization as init from ...activations import ACT2CLS, ACT2FN -from ...auto import AutoBackbone from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer @@ -39,7 +38,7 @@ logging, torch_compilable_check, ) -from ..auto import AutoModel +from ..auto import AutoBackbone, AutoModel from .configuration_omdet_turbo import OmDetTurboConfig diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index ae00809eaece..feb959e86b14 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -24,7 +24,6 @@ from ... import initialization as init from ...activations import ACT2FN -from ...auto import AutoBackbone from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel @@ -39,6 +38,7 @@ torch_compilable_check, ) from ...utils.generic import maybe_autocast +from ..auto import AutoBackbone from .configuration_oneformer import OneFormerConfig diff --git a/src/transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py b/src/transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py index f54dea229a71..e5ee32f818bf 100644 --- a/src/transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +++ b/src/transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py @@ -30,7 +30,6 @@ from ... 
import initialization as init from ...activations import ACT2CLS, ACT2FN -from ...auto import AutoBackbone from ...image_transforms import center_to_corners_format, corners_to_center_format from ...integrations import use_kernel_forward_from_hub from ...modeling_outputs import BaseModelOutput @@ -46,6 +45,7 @@ ) from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ..auto import AutoBackbone from .configuration_pp_doclayout_v3 import PPDocLayoutV3Config diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index c84b35e476e4..bca80f173dd0 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -20,11 +20,11 @@ import torch import torch.nn as nn -from ...auto import AutoBackbone from ...modeling_outputs import DepthEstimatorOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring from ...utils.generic import torch_int +from ..auto import AutoBackbone from .configuration_prompt_depth_anything import PromptDepthAnythingConfig diff --git a/src/transformers/models/rt_detr/modeling_rt_detr.py b/src/transformers/models/rt_detr/modeling_rt_detr.py index f5c799b4115f..ec9aa7fbe7f3 100644 --- a/src/transformers/models/rt_detr/modeling_rt_detr.py +++ b/src/transformers/models/rt_detr/modeling_rt_detr.py @@ -28,7 +28,6 @@ from ... import initialization as init from ...activations import ACT2CLS, ACT2FN -from ...auto import AutoBackbone from ...image_transforms import center_to_corners_format, corners_to_center_format from ...integrations import use_kernel_forward_from_hub from ...modeling_outputs import BaseModelOutput @@ -38,6 +37,7 @@ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check, torch_int from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ..auto import AutoBackbone from .configuration_rt_detr import RTDetrConfig diff --git a/src/transformers/models/rt_detr/modular_rt_detr.py b/src/transformers/models/rt_detr/modular_rt_detr.py index b0b771c48678..28f33ddf55a1 100644 --- a/src/transformers/models/rt_detr/modular_rt_detr.py +++ b/src/transformers/models/rt_detr/modular_rt_detr.py @@ -23,7 +23,6 @@ from ... 
import initialization as init from ...activations import ACT2CLS, ACT2FN -from ...auto import AutoBackbone from ...image_processing_utils import BatchFeature from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict, get_max_height_width from ...image_transforms import center_to_corners_format, corners_to_center_format @@ -52,6 +51,7 @@ ) from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ..auto import AutoBackbone from ..conditional_detr.modeling_conditional_detr import inverse_sigmoid from ..deformable_detr.modeling_deformable_detr import DeformableDetrMultiscaleDeformableAttention from ..detr.image_processing_detr_fast import DetrImageProcessorFast diff --git a/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py b/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py index 3921dd0791d1..3a31e82040a7 100644 --- a/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +++ b/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py @@ -29,7 +29,6 @@ from ... import initialization as init from ...activations import ACT2CLS, ACT2FN -from ...auto import AutoBackbone from ...image_transforms import center_to_corners_format, corners_to_center_format from ...modeling_outputs import BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -38,6 +37,7 @@ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check, torch_int from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ..auto import AutoBackbone from .configuration_rt_detr_v2 import RTDetrV2Config diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index b39c296d9645..fb4c5dec9e9a 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -21,7 +21,6 @@ from ... import initialization as init from ...activations import ACT2FN -from ...auto import AutoBackbone from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput @@ -31,6 +30,7 @@ auto_docstring, logging, ) +from ..auto import AutoBackbone from .configuration_table_transformer import TableTransformerConfig diff --git a/src/transformers/models/tvp/modeling_tvp.py b/src/transformers/models/tvp/modeling_tvp.py index 04aec13d06ed..801da7b399d9 100644 --- a/src/transformers/models/tvp/modeling_tvp.py +++ b/src/transformers/models/tvp/modeling_tvp.py @@ -21,11 +21,11 @@ from ... 
import initialization as init from ...activations import ACT2FN -from ...auto import AutoBackbone from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging +from ..auto import AutoBackbone from .configuration_tvp import TvpConfig diff --git a/src/transformers/models/upernet/modeling_upernet.py b/src/transformers/models/upernet/modeling_upernet.py index 4217c2c5b832..61db23ce7c63 100644 --- a/src/transformers/models/upernet/modeling_upernet.py +++ b/src/transformers/models/upernet/modeling_upernet.py @@ -17,10 +17,10 @@ from torch import nn from torch.nn import CrossEntropyLoss -from ...auto import AutoBackbone from ...modeling_outputs import SemanticSegmenterOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring +from ..auto import AutoBackbone from .configuration_upernet import UperNetConfig diff --git a/src/transformers/models/vitmatte/modeling_vitmatte.py b/src/transformers/models/vitmatte/modeling_vitmatte.py index b4497c459993..7d05f1e5d3a6 100644 --- a/src/transformers/models/vitmatte/modeling_vitmatte.py +++ b/src/transformers/models/vitmatte/modeling_vitmatte.py @@ -19,9 +19,9 @@ from torch import nn from ... import initialization as init -from ...auto import AutoBackbone from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring +from ..auto import AutoBackbone from .configuration_vitmatte import VitMatteConfig diff --git a/src/transformers/models/vitpose/modeling_vitpose.py b/src/transformers/models/vitpose/modeling_vitpose.py index 66dcfb8d9f2a..f077a2e346c7 100644 --- a/src/transformers/models/vitpose/modeling_vitpose.py +++ b/src/transformers/models/vitpose/modeling_vitpose.py @@ -19,12 +19,12 @@ from torch import nn from ... import initialization as init -from ...auto import AutoBackbone from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging from ...utils.generic import can_return_tuple +from ..auto import AutoBackbone from .configuration_vitpose import VitPoseConfig diff --git a/src/transformers/models/zoedepth/modeling_zoedepth.py b/src/transformers/models/zoedepth/modeling_zoedepth.py index 8cbb747d5209..399e677fa696 100644 --- a/src/transformers/models/zoedepth/modeling_zoedepth.py +++ b/src/transformers/models/zoedepth/modeling_zoedepth.py @@ -21,10 +21,10 @@ from ... 
import initialization as init from ...activations import ACT2FN -from ...auto import AutoBackbone from ...modeling_outputs import DepthEstimatorOutput from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging +from ..auto import AutoBackbone from .configuration_zoedepth import ZoeDepthConfig From 8a6e379f94fc81c784f175e4a21a9a72fa54731c Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 24 Feb 2026 18:17:13 +0100 Subject: [PATCH 09/22] last test and fix repo --- .../models/gemma3n/configuration_gemma3n.py | 57 +------------------ .../models/gemma3n/modeling_gemma3n.py | 8 +-- .../models/gemma3n/modular_gemma3n.py | 7 ++- .../test_modeling_timm_backbone.py | 2 +- tests/utils/test_backbone_utils.py | 23 +++++--- 5 files changed, 23 insertions(+), 74 deletions(-) diff --git a/src/transformers/models/gemma3n/configuration_gemma3n.py b/src/transformers/models/gemma3n/configuration_gemma3n.py index 37e076f5861e..c8193622d5ad 100644 --- a/src/transformers/models/gemma3n/configuration_gemma3n.py +++ b/src/transformers/models/gemma3n/configuration_gemma3n.py @@ -23,11 +23,7 @@ from ...configuration_utils import PreTrainedConfig, layer_type_validation from ...modeling_rope_utils import RopeParameters -from ...utils import is_timm_available, logging, requires_backends - - -if is_timm_available(): - from timm.data import ImageNetInfo, infer_imagenet_subset +from ...utils import logging logger = logging.get_logger(__name__) @@ -502,57 +498,6 @@ def __init__( self.model_args = model_args # named "model_args" for BC with timm super().__init__(**kwargs) - @classmethod - def from_dict(cls, config_dict: dict[str, Any], **kwargs): - # Create a copy to avoid mutating the original dict - config_dict = config_dict.copy() - - label_names = config_dict.get("label_names") - is_custom_model = "num_labels" in kwargs or "id2label" in kwargs - - # if no labels added to config, use imagenet labeller in timm - if label_names is None and not is_custom_model: - requires_backends(cls, ["timm"]) - imagenet_subset = infer_imagenet_subset(config_dict) - if imagenet_subset: - dataset_info = ImageNetInfo(imagenet_subset) - synsets = dataset_info.label_names() - label_descriptions = dataset_info.label_descriptions(as_dict=True) - label_names = [label_descriptions[synset] for synset in synsets] - - if label_names is not None and not is_custom_model: - kwargs["id2label"] = dict(enumerate(label_names)) - - # if all label names are unique, create label2id mapping as well - if len(set(label_names)) == len(label_names): - kwargs["label2id"] = {name: i for i, name in enumerate(label_names)} - else: - kwargs["label2id"] = None - - # timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict. 
- # We are removing these attributes in order to have the native `transformers` num_labels attribute in config - # and to avoid duplicate attributes - num_labels_in_kwargs = kwargs.pop("num_labels", None) - num_labels_in_dict = config_dict.pop("num_classes", None) - - # passed num_labels has priority over num_classes in config_dict - kwargs["num_labels"] = num_labels_in_kwargs or num_labels_in_dict - - # pop num_classes from "pretrained_cfg", - # it is not necessary to have it, only root one is used in timm - if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]: - config_dict["pretrained_cfg"].pop("num_classes", None) - - return super().from_dict(config_dict, **kwargs) - - def to_dict(self) -> dict[str, Any]: - output = super().to_dict() - output.setdefault("num_classes", self.num_labels) - output.setdefault("label_names", list(self.id2label.values())) - output.pop("id2label", None) - output.pop("label2id", None) - return output - class Gemma3nConfig(PreTrainedConfig): r""" diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index e22def7b0d87..65ecd63f035b 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -38,13 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - ModelOutput, - TransformersKwargs, - auto_docstring, - can_return_tuple, - torch_compilable_check, -) +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index a97cc2823c7b..d563a7e57d4e 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -56,7 +56,6 @@ PaliGemmaModel, PaligemmaModelOutputWithPast, ) -from ..timm_wrapper.configuration_timm_wrapper import TimmWrapperConfig logger = logging.get_logger(__name__) @@ -444,7 +443,7 @@ def __init__( self.sscp_conv_stride_size = sscp_conv_stride_size -class Gemma3nVisionConfig(TimmWrapperConfig): +class Gemma3nVisionConfig(PreTrainedConfig): r""" This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`]. It is used to instantiate an timm model model according to the specified arguments, defining the model architecture. 
@@ -510,6 +509,10 @@ def __init__( self.vocab_size = vocab_size self.vocab_offset = vocab_offset self.rms_norm_eps = rms_norm_eps + self.architecture = architecture + self.initializer_range = initializer_range + self.do_pooling = do_pooling + self.model_args = model_args # named "model_args" for BC with timm super().__init__(**kwargs) diff --git a/tests/models/timm_backbone/test_modeling_timm_backbone.py b/tests/models/timm_backbone/test_modeling_timm_backbone.py index 9ab71fa94a50..5410b0e7298c 100644 --- a/tests/models/timm_backbone/test_modeling_timm_backbone.py +++ b/tests/models/timm_backbone/test_modeling_timm_backbone.py @@ -42,7 +42,7 @@ def __init__( batch_size=3, image_size=32, num_channels=3, - is_training=True, + is_training=False, ): self.parent = parent self.out_indices = out_indices if out_indices is not None else [4] diff --git a/tests/utils/test_backbone_utils.py b/tests/utils/test_backbone_utils.py index a27ced73018f..906986cd5d72 100644 --- a/tests/utils/test_backbone_utils.py +++ b/tests/utils/test_backbone_utils.py @@ -16,11 +16,18 @@ import pytest -from transformers import DetrConfig, MaskFormerConfig, PreTrainedConfig, ResNetBackbone, ResNetConfig, TimmBackbone +from transformers import ( + AutoBackbone, + DetrConfig, + MaskFormerConfig, + PreTrainedConfig, + ResNetBackbone, + ResNetConfig, + TimmBackbone, +) from transformers.backbone_utils import ( BackboneConfigMixin, BackboneMixin, - load_backbone, ) from transformers.testing_utils import require_torch, slow from transformers.utils.import_utils import is_torch_available @@ -160,7 +167,7 @@ def test_load_backbone_from_config(self): Test that load_backbone correctly loads a backbone from a backbone config. """ config = MaskFormerConfig(backbone_config=ResNetConfig(out_indices=(0, 2))) - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config.backbone_config) self.assertEqual(backbone.out_features, ["stem", "stage2"]) self.assertEqual(backbone.out_indices, (0, 2)) self.assertIsInstance(backbone, ResNetBackbone) @@ -172,7 +179,7 @@ def test_load_backbone_from_checkpoint(self): Test that load_backbone correctly loads a backbone from a checkpoint. """ config = MaskFormerConfig(backbone="microsoft/resnet-18", backbone_config=None) - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config.backbone_config) self.assertEqual(backbone.out_indices, [4]) self.assertEqual(backbone.out_features, ["stage4"]) self.assertIsInstance(backbone, ResNetBackbone) @@ -181,7 +188,7 @@ def test_load_backbone_from_checkpoint(self): backbone="resnet18", use_timm_backbone=True, ) - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config.backbone_config) # We can't know ahead of time the exact output features and indices, or the layer names before # creating the timm model, so it defaults to the last layer (-1,) and has a different layer name self.assertEqual(backbone.out_indices, (-1,)) @@ -195,12 +202,12 @@ def test_load_backbone_backbone_kwargs(self): Test that load_backbone correctly configures the loaded backbone with the provided kwargs. 
""" config = MaskFormerConfig(backbone="resnet18", use_timm_backbone=True, backbone_kwargs={"out_indices": (0, 1)}) - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config.backbone_config) self.assertEqual(backbone.out_indices, (0, 1)) self.assertIsInstance(backbone, TimmBackbone) config = MaskFormerConfig(backbone="microsoft/resnet-18", backbone_kwargs={"out_indices": (0, 2)}) - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config.backbone_config) self.assertEqual(backbone.out_indices, (0, 2)) self.assertIsInstance(backbone, ResNetBackbone) @@ -223,7 +230,7 @@ def test_load_backbone_in_new_model(self): class NewModel(BertPreTrainedModel): def __init__(self, config): super().__init__(config) - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config.backbone_config) self.layer_0 = torch.nn.Linear(config.hidden_size, config.hidden_size) self.layer_1 = torch.nn.Linear(config.hidden_size, config.hidden_size) From 04e0d8ad1ab6a7041315c1da6be276f95e19d158 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 24 Feb 2026 18:29:15 +0100 Subject: [PATCH 10/22] and here maybe --- src/transformers/backbone_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/backbone_utils.py b/src/transformers/backbone_utils.py index 1d3a3501eeab..bacb3ea6973b 100644 --- a/src/transformers/backbone_utils.py +++ b/src/transformers/backbone_utils.py @@ -297,15 +297,17 @@ def consolidate_backbone_kwargs_to_config( and backbone_config is None and not backbone_kwargs ): - backbone_config = CONFIG_MAPPING["timm_backbone"](backbone=backbone, **timm_default_kwargs) + backbone_config = CONFIG_MAPPING["timm_wrapper"](backbone=backbone, **timm_default_kwargs) elif backbone is not None and backbone_config is None: if repo_exists(backbone): config_dict, _ = PreTrainedConfig.get_config_dict(backbone) + if config_dict["model_type"] == "timm_backbone": + config_dict["model_type"] = "timm_wrapper" config_class = CONFIG_MAPPING[config_dict["model_type"]] config_dict.update(backbone_kwargs) backbone_config = config_class(**config_dict) else: - backbone_config = CONFIG_MAPPING["timm_backbone"](backbone=backbone, **backbone_kwargs) + backbone_config = CONFIG_MAPPING["timm_wrapper"](backbone=backbone, **backbone_kwargs) elif backbone_config is None and default_config_type is not None: logger.info( f"`backbone_config` is `None`. Initializing the config with the default `{default_config_type}` vision config." 
@@ -314,6 +316,8 @@ def consolidate_backbone_kwargs_to_config( backbone_config = CONFIG_MAPPING[default_config_type](**default_config_kwargs) elif isinstance(backbone_config, dict): backbone_model_type = backbone_config.get("model_type") + if backbone_model_type == "timm_backbone": + backbone_model_type = "timm_wrapper" config_class = CONFIG_MAPPING[backbone_model_type] backbone_config = config_class.from_dict(backbone_config) From ba56a16b00fa0479b72fcf77ca71eeb00bc23184 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 24 Feb 2026 18:30:50 +0100 Subject: [PATCH 11/22] and one more test --- src/transformers/models/auto/auto_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index 166c0b2190a4..47f56954a6b5 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -268,7 +268,7 @@ def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *mod # Always load a pretrained model when `from_pretrained` is called kwargs.pop("use_pretrained_backbone", None) - return super().from_config(config, pretrained=True, **kwargs) + return cls.from_config(config, pretrained=True, **kwargs) @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], *model_args, **kwargs): From a829360657d5adbe6f1989d247867112d2c2dce0 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 25 Feb 2026 10:43:05 +0100 Subject: [PATCH 12/22] docstring --- src/transformers/models/timm_wrapper/modeling_timm_wrapper.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py index 8e04e13d16d7..042c9c41027e 100644 --- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -259,6 +259,10 @@ def forward( **kwargs, ) -> TimmWrapperModelOutput | tuple[Tensor, ...]: r""" + do_pooling (`bool`, *optional*): + Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not. If `None` is passed, the + `do_pooling` value from the config is used. 
+ Examples: ```python >>> import torch From 100c489065d85af312de77ffd6c5e7717cb97cb8 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 25 Feb 2026 11:26:12 +0100 Subject: [PATCH 13/22] deprecation messages --- .../models/conditional_detr/modeling_conditional_detr.py | 6 +++--- .../models/deformable_detr/modeling_deformable_detr.py | 6 +++--- .../models/deformable_detr/modular_deformable_detr.py | 6 +++--- src/transformers/models/detr/modeling_detr.py | 6 +++--- .../models/omdet_turbo/convert_omdet_turbo_to_hf.py | 2 +- .../models/table_transformer/modeling_table_transformer.py | 6 +++--- .../models/timm_backbone/configuration_timm_backbone.py | 5 ++++- .../models/timm_backbone/modeling_timm_backbone.py | 5 ++++- .../models/timm_wrapper/modeling_timm_wrapper.py | 4 +++- 9 files changed, 27 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 3d3d81450fb9..fe6a66c79b8c 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -266,10 +266,10 @@ def __init__(self, config): replace_batch_norm(backbone) # We used to load with timm library directly instead of the AutoBackbone API - # so we need to unwrap the `backbone._backbone` module to load weights without mismatch + # so we need to unwrap the `backbone.timm_model` module to load weights without mismatch is_timm_model = False - if hasattr(backbone, "_backbone"): - backbone = backbone._backbone + if hasattr(backbone, "timm_model"): + backbone = backbone.timm_model is_timm_model = True self.model = backbone diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 82bb58b200ae..e269c3c708b2 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -306,10 +306,10 @@ def __init__(self, config): replace_batch_norm(backbone) # We used to load with timm library directly instead of the AutoBackbone API - # so we need to unwrap the `backbone._backbone` module to load weights without mismatch + # so we need to unwrap the `backbone.timm_model` module to load weights without mismatch is_timm_model = False - if hasattr(backbone, "_backbone"): - backbone = backbone._backbone + if hasattr(backbone, "timm_model"): + backbone = backbone.timm_model is_timm_model = True self.model = backbone diff --git a/src/transformers/models/deformable_detr/modular_deformable_detr.py b/src/transformers/models/deformable_detr/modular_deformable_detr.py index 886302bca645..118ea209ad54 100644 --- a/src/transformers/models/deformable_detr/modular_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modular_deformable_detr.py @@ -300,10 +300,10 @@ def __init__(self, config): replace_batch_norm(backbone) # We used to load with timm library directly instead of the AutoBackbone API - # so we need to unwrap the `backbone._backbone` module to load weights without mismatch + # so we need to unwrap the `backbone.timm_model` module to load weights without mismatch is_timm_model = False - if hasattr(backbone, "_backbone"): - backbone = backbone._backbone + if hasattr(backbone, "timm_model"): + backbone = backbone.timm_model is_timm_model = True self.model = backbone diff --git a/src/transformers/models/detr/modeling_detr.py 
b/src/transformers/models/detr/modeling_detr.py index ed50792d9e7b..d45427ba9c8a 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -266,10 +266,10 @@ def __init__(self, config): replace_batch_norm(backbone) # We used to load with timm library directly instead of the AutoBackbone API - # so we need to unwrap the `backbone._backbone` module to load weights without mismatch + # so we need to unwrap the `backbone.timm_model` module to load weights without mismatch is_timm_model = False - if hasattr(backbone, "_backbone"): - backbone = backbone._backbone + if hasattr(backbone, "timm_model"): + backbone = backbone.timm_model is_timm_model = True self.model = backbone diff --git a/src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py b/src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py index 79aaddc2a8ba..093337f4ec72 100644 --- a/src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py +++ b/src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py @@ -69,7 +69,7 @@ def create_rename_keys_vision(state_dict, config): for layer_name in state_dict: if layer_name.startswith("backbone") and not layer_name.startswith("backbone.norm"): if config.use_timm_backbone: - layer_name_replace = layer_name.replace("backbone", "vision_backbone.vision_backbone._backbone") + layer_name_replace = layer_name.replace("backbone", "vision_backbone.vision_backbone.timm_model") layer_name_replace = layer_name_replace.replace(".layers.", ".layers_") if "downsample" in layer_name: # get layer number diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index fb4c5dec9e9a..ccc203559807 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -212,10 +212,10 @@ def __init__(self, config): replace_batch_norm(backbone) # We used to load with timm library directly instead of the AutoBackbone API - # so we need to unwrap the `backbone._backbone` module to load weights without mismatch + # so we need to unwrap the `backbone.timm_model` module to load weights without mismatch is_timm_model = False - if hasattr(backbone, "_backbone"): - backbone = backbone._backbone + if hasattr(backbone, "timm_model"): + backbone = backbone.timm_model is_timm_model = True self.model = backbone diff --git a/src/transformers/models/timm_backbone/configuration_timm_backbone.py b/src/transformers/models/timm_backbone/configuration_timm_backbone.py index d4a9b0263cce..ef15e98ebb28 100644 --- a/src/transformers/models/timm_backbone/configuration_timm_backbone.py +++ b/src/transformers/models/timm_backbone/configuration_timm_backbone.py @@ -85,7 +85,10 @@ def __init__( "in_chans": num_channels, "output_stride": output_stride, } - logger.warning("Deprecation message!") + logger.warning( + "TimmBackboneConfig is deprecated and will be removed in future versions. " + "Use a TimmWrapperConfig instead with TimmWrapperBackboneModel to extract features."
+ ) super().__init__(freeze_batch_norm_2d=freeze_batch_norm_2d, **kwargs) def __setattr__(self, key, value): diff --git a/src/transformers/models/timm_backbone/modeling_timm_backbone.py index 4041e1879900..7e57d64b3514 100644 --- a/src/transformers/models/timm_backbone/modeling_timm_backbone.py +++ b/src/transformers/models/timm_backbone/modeling_timm_backbone.py @@ -22,7 +22,10 @@ class TimmBackbone(TimmWrapperBackboneModel): def __init__(self, *args, **kwargs): - logger.warning("Deprecation message") + logger.warning( + "`TimmBackbone` is deprecated and will be removed in future versions. Use a " + "`TimmWrapperBackboneModel` initialized from `TimmWrapperConfig` instead to extract features." + ) super().__init__(*args, **kwargs) diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py index 042c9c41027e..a7d77815d909 100644 --- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -197,7 +197,9 @@ def __init__(self, config, **kwargs): @property def _backbone(self): - logger.warning("Deprecation msg") + logger.warning( + f"The `self._backbone` attribute is deprecated for {self.__class__.__name__}. Please use `self.timm_model` instead." + ) return self.timm_model @can_return_tuple From 854f2fb9e4e37385bd32e1e9e15da77d6377f565 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 25 Feb 2026 12:41:38 +0100 Subject: [PATCH 14/22] add test and docs --- docs/source/en/model_doc/timm_wrapper.md | 5 +++ .../test_modeling_timm_wrapper.py | 34 ++++++++++++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/timm_wrapper.md b/docs/source/en/model_doc/timm_wrapper.md index c6e508b9ffe1..81cf6e04ba81 100644 --- a/docs/source/en/model_doc/timm_wrapper.md +++ b/docs/source/en/model_doc/timm_wrapper.md @@ -71,6 +71,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h [[autodoc]] TimmWrapperImageProcessor - preprocess +## TimmWrapperBackboneModel + +[[autodoc]] TimmWrapperBackboneModel + - forward + ## TimmWrapperModel [[autodoc]] TimmWrapperModel diff --git a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py index 46ec5c01fe0f..969e5c45cee3 100644 --- a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py +++ b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py @@ -28,6 +28,7 @@ ) from transformers.utils.import_utils import is_timm_available, is_torch_available, is_vision_available +from ...test_backbone_common import BackboneTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor from ...test_pipeline_mixin import PipelineTesterMixin @@ -36,7 +37,12 @@ if is_torch_available(): import torch - from transformers import TimmWrapperConfig, TimmWrapperForImageClassification, TimmWrapperModel + from transformers import ( + TimmWrapperBackboneModel, + TimmWrapperConfig, + TimmWrapperForImageClassification, + TimmWrapperModel, + ) if is_timm_available(): @@ -253,6 +259,32 @@ def test_model_init_args(self): self.assertEqual(len(restored_model.timm_model.blocks), 3) +class TimmWrapperBackboneModelTester(TimmWrapperModelTester): + def __init__(self, parent, **kwargs): + super().__init__(parent, **kwargs) + # Add `features_only` for the backbone model + self.model_args = {"channels":
(16, 16, 16, 16), "features_only": True} + + +@require_torch +class TimmWrapperBackboneTest(BackboneTesterMixin, unittest.TestCase): + all_model_classes = (TimmWrapperBackboneModel,) if is_torch_available() else () + has_attentions = False + config_class = TimmWrapperConfig + + def setUp(self): + self.model_tester = TimmWrapperBackboneModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=self.config_class, + has_text_modality=False, + common_properties=["architecture"], + ) + + def test_config(self): + self.config_tester.run_common_tests() + + # We will verify our results on an image of cute cats def prepare_img(): image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") From c6476bcf5cfe58afcaee852cb2c9d8fc2c6135ea Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 25 Feb 2026 15:14:57 +0100 Subject: [PATCH 15/22] allow loading old checkpoints, if any. Official ckpt don't exist with timm ig --- src/transformers/conversion_mapping.py | 14 +++++++------- src/transformers/modeling_utils.py | 20 ++++++++++++++++++-- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 03c056c9124e..6ad420c6b44c 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -295,13 +295,13 @@ def _build_checkpoint_conversion_mapping(): operations=[MergeModulelist(dim=0)], ), ], - "timm_wrapper": [ - # Simply add the prefix `timm_model` - # TODO: Would be probably much cleaner with a `add_prefix` argument in WeightRenaming + "timm_backbone": [ + # For BC with backbone model after deprecating `TimmBackbone` model class + # TODO: the conversion mapping doesn't work well with literal dots (r'\.') in source WeightRenaming( - source_patterns=r"(.+)", - target_patterns=r"timm_model.\1", - ) + source_patterns=r"(? 
Date: Wed, 25 Feb 2026 15:28:06 +0100 Subject: [PATCH 16/22] dont create imagenet 1000 labels by default --- .../configuration_timm_wrapper.py | 31 +++++-------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py index 55250fb90625..1e2365c64c3a 100644 --- a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py @@ -18,11 +18,11 @@ from ...backbone_utils import BackboneConfigMixin from ...configuration_utils import PreTrainedConfig -from ...utils import is_timm_available, logging, requires_backends +from ...utils import is_timm_available, logging if is_timm_available(): - from timm.data import ImageNetInfo, infer_imagenet_subset + pass logger = logging.get_logger(__name__) @@ -98,17 +98,6 @@ def from_dict(cls, config_dict: dict[str, Any], **kwargs): label_names = config_dict.get("label_names") is_custom_model = "num_labels" in kwargs or "id2label" in kwargs - - # if no labels added to config, use imagenet labeller in timm - if label_names is None and not is_custom_model: - requires_backends(cls, ["timm"]) - imagenet_subset = infer_imagenet_subset(config_dict) - if imagenet_subset: - dataset_info = ImageNetInfo(imagenet_subset) - synsets = dataset_info.label_names() - label_descriptions = dataset_info.label_descriptions(as_dict=True) - label_names = [label_descriptions[synset] for synset in synsets] - if label_names is not None and not is_custom_model: kwargs["id2label"] = dict(enumerate(label_names)) @@ -120,17 +109,13 @@ def from_dict(cls, config_dict: dict[str, Any], **kwargs): # timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict. # We are removing these attributes in order to have the native `transformers` num_labels attribute in config - # and to avoid duplicate attributes - num_labels_in_kwargs = kwargs.pop("num_labels", None) - num_labels_in_dict = config_dict.pop("num_classes", None) - - # passed num_labels has priority over num_classes in config_dict - kwargs["num_labels"] = num_labels_in_kwargs or num_labels_in_dict + # and to avoid duplicate attributes. 
Note that `num_labels` has priority over `num_classes` in config_dict + if config_dict.get("num_classes") is not None and kwargs.get("num_labels") is None: + kwargs["num_labels"] = config_dict.pop("num_classes", None) - # pop num_classes from "pretrained_cfg", - # it is not necessary to have it, only root one is used in timm - if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]: - config_dict["pretrained_cfg"].pop("num_classes", None) + # Pop in nested `pretrained_cfg` as well + if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]: + config_dict["pretrained_cfg"].pop("num_classes", None) return super().from_dict(config_dict, **kwargs) From ab6def35d8c26b84653bd77af00d48467eb4ed6b Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 25 Feb 2026 17:02:32 +0100 Subject: [PATCH 17/22] fix tests --- src/transformers/backbone_utils.py | 19 +++++++++++++++---- src/transformers/conversion_mapping.py | 1 - src/transformers/modeling_utils.py | 17 ++++++++--------- .../configuration_conditional_detr.py | 10 ++++++---- .../models/dab_detr/configuration_dab_detr.py | 10 ++++++---- .../configuration_deformable_detr.py | 10 ++++++---- .../models/detr/configuration_detr.py | 10 ++++++---- .../omdet_turbo/configuration_omdet_turbo.py | 8 ++++++-- .../configuration_table_transformer.py | 8 +++++--- .../configuration_timm_wrapper.py | 10 +--------- .../timm_wrapper/modeling_timm_wrapper.py | 5 ++++- 11 files changed, 63 insertions(+), 45 deletions(-) diff --git a/src/transformers/backbone_utils.py b/src/transformers/backbone_utils.py index bacb3ea6973b..0cc9041dc280 100644 --- a/src/transformers/backbone_utils.py +++ b/src/transformers/backbone_utils.py @@ -297,17 +297,22 @@ def consolidate_backbone_kwargs_to_config( and backbone_config is None and not backbone_kwargs ): - backbone_config = CONFIG_MAPPING["timm_wrapper"](backbone=backbone, **timm_default_kwargs) + backbone_config = CONFIG_MAPPING["timm_wrapper"](backbone, **timm_default_kwargs) elif backbone is not None and backbone_config is None: if repo_exists(backbone): config_dict, _ = PreTrainedConfig.get_config_dict(backbone) - if config_dict["model_type"] == "timm_backbone": - config_dict["model_type"] = "timm_wrapper" config_class = CONFIG_MAPPING[config_dict["model_type"]] config_dict.update(backbone_kwargs) backbone_config = config_class(**config_dict) else: - backbone_config = CONFIG_MAPPING["timm_wrapper"](backbone=backbone, **backbone_kwargs) + # Move timm-args inside `model_args` to support loading from TimmBackboneConfig + if "model_args" not in backbone_kwargs: + backbone_kwargs["model_args"] = { + "in_chans": backbone_kwargs.pop("num_channels", 3), + "features_only": backbone_kwargs.pop("features_only", True), + "output_stride": backbone_kwargs.pop("output_stride", None), + } + backbone_config = CONFIG_MAPPING["timm_wrapper"](backbone, **backbone_kwargs) elif backbone_config is None and default_config_type is not None: logger.info( f"`backbone_config` is `None`. Initializing the config with the default `{default_config_type}` vision config."
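# Illustrative sketch, not part of the patch: the consolidation above moves the legacy flat timm
# kwargs (`num_channels`, `features_only`, `output_stride`) into the nested `model_args` dict that
# `TimmWrapperConfig` now expects. The helper name `to_model_args` and the sample kwargs are made
# up for this example; the mapping itself mirrors `consolidate_backbone_kwargs_to_config`.
def to_model_args(backbone_kwargs: dict) -> dict:
    if "model_args" not in backbone_kwargs:
        backbone_kwargs["model_args"] = {
            "in_chans": backbone_kwargs.pop("num_channels", 3),
            "features_only": backbone_kwargs.pop("features_only", True),
            "output_stride": backbone_kwargs.pop("output_stride", None),
        }
    return backbone_kwargs

assert to_model_args({"num_channels": 3, "out_indices": [1, 2, 3, 4]}) == {
    "out_indices": [1, 2, 3, 4],
    "model_args": {"in_chans": 3, "features_only": True, "output_stride": None},
}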
@@ -318,6 +323,12 @@ def consolidate_backbone_kwargs_to_config( backbone_model_type = backbone_config.get("model_type") if backbone_model_type == "timm_backbone": backbone_model_type = "timm_wrapper" + # Move timm-args inside `model_args` + backbone_config["model_args"] = { + "in_chans": backbone_config.pop("num_channels", 3), + "features_only": backbone_config.pop("features_only", True), + "output_stride": backbone_config.pop("output_stride", None), + } config_class = CONFIG_MAPPING[backbone_model_type] backbone_config = config_class.from_dict(backbone_config) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 6ad420c6b44c..d13b8daeb195 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -422,7 +422,6 @@ def get_model_conversion_mapping( for k, v in model._checkpoint_conversion_mapping.items() ] - # weight_conversions = get_weight_conversions_resursively(model, weight_conversions=weight_conversions) model_type = getattr(model.config, "model_type", None) if model_type is not None: model_specific_conversions = get_checkpoint_conversion_mapping(model_type) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e586e32a3a11..fbba6696fed5 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4067,9 +4067,7 @@ def from_pretrained( dtype_plan = model._get_dtype_plan(dtype) # Obtain the weight conversion mapping for this model if any are registered and appy to all submodels recursively - weight_conversions = cls.get_weight_conversions_recursively( - model, key_mapping, hf_quantizer, weight_conversions=[] - ) + weight_conversions = cls.get_weight_conversions_recursively(model, key_mapping, hf_quantizer) if _torch_distributed_available and device_mesh is not None: # add hooks to nn.Modules: no weights model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size) @@ -4268,18 +4266,19 @@ def _finalize_model_loading( return loading_info @classmethod - def get_weight_conversions_recursively(cls, model, key_mapping, hf_quantizer, weight_conversions): + def get_weight_conversions_recursively(cls, model, key_mapping, hf_quantizer): + conversions = [] + conversions.extend(get_model_conversion_mapping(model, key_mapping, hf_quantizer)) + for submodule in model.modules(): if ( submodule is not model and isinstance(submodule, PreTrainedModel) and submodule.config.__class__ != model.config.__class__ ): - weight_conversions_submodel = get_model_conversion_mapping( - submodule, key_mapping, hf_quantizer, weight_conversions - ) - weight_conversions.extend(weight_conversions_submodel) - return weight_conversions + conversions.extend(get_model_conversion_mapping(submodule, key_mapping, hf_quantizer)) + conversions.extend(cls.get_weight_conversions_recursively(submodule, key_mapping, hf_quantizer)) + return conversions def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False): module_keys = {".".join(key.split(".")[:-1]) for key in names} diff --git a/src/transformers/models/conditional_detr/configuration_conditional_detr.py b/src/transformers/models/conditional_detr/configuration_conditional_detr.py index a8ad94920057..de55f837d578 100644 --- a/src/transformers/models/conditional_detr/configuration_conditional_detr.py +++ b/src/transformers/models/conditional_detr/configuration_conditional_detr.py @@ -161,13 +161,15 @@ def __init__( # Init timm backbone with hardcoded values for BC backbone_kwargs = 
kwargs.get("backbone_kwargs", {}) timm_default_kwargs = { - "num_channels": backbone_kwargs.get("num_channels", num_channels), - "features_only": True, - "use_pretrained_backbone": False, + "model_args": { + "in_chans": backbone_kwargs.get("num_channels", num_channels), + "features_only": True, + "pretrained": False, + }, "out_indices": backbone_kwargs.get("out_indices", [1, 2, 3, 4]), } if dilation: - timm_default_kwargs["output_stride"] = backbone_kwargs.get("output_stride", 16) + timm_default_kwargs["model_args"]["output_stride"] = backbone_kwargs.get("output_stride", 16) backbone_config, kwargs = consolidate_backbone_kwargs_to_config( backbone_config=backbone_config, diff --git a/src/transformers/models/dab_detr/configuration_dab_detr.py b/src/transformers/models/dab_detr/configuration_dab_detr.py index f4e3d2e062ad..abc87297b063 100644 --- a/src/transformers/models/dab_detr/configuration_dab_detr.py +++ b/src/transformers/models/dab_detr/configuration_dab_detr.py @@ -175,13 +175,15 @@ def __init__( # Init timm backbone with hardcoded values for BC timm_default_kwargs = { - "num_channels": 3, - "features_only": True, - "use_pretrained_backbone": False, + "model_args": { + "in_chans": 3, + "features_only": True, + "pretrained": False, + }, "out_indices": [1, 2, 3, 4], } if dilation: - timm_default_kwargs["output_stride"] = 16 + timm_default_kwargs["model_args"]["output_stride"] = 16 backbone_config, kwargs = consolidate_backbone_kwargs_to_config( backbone_config=backbone_config, diff --git a/src/transformers/models/deformable_detr/configuration_deformable_detr.py b/src/transformers/models/deformable_detr/configuration_deformable_detr.py index ad7f56eb763a..78f57d132cb5 100644 --- a/src/transformers/models/deformable_detr/configuration_deformable_detr.py +++ b/src/transformers/models/deformable_detr/configuration_deformable_detr.py @@ -185,13 +185,15 @@ def __init__( ): # Init timm backbone with hardcoded values for BC timm_default_kwargs = { - "num_channels": 3, - "features_only": True, - "use_pretrained_backbone": False, + "model_args": { + "in_chans": 3, + "features_only": True, + "pretrained": False, + }, "out_indices": [2, 3, 4] if num_feature_levels > 1 else [4], } if dilation: - timm_default_kwargs["output_stride"] = 16 + timm_default_kwargs["model_args"]["output_stride"] = 16 backbone_config, kwargs = consolidate_backbone_kwargs_to_config( backbone_config=backbone_config, diff --git a/src/transformers/models/detr/configuration_detr.py b/src/transformers/models/detr/configuration_detr.py index 87835ba098a2..4fc6728bcd90 100644 --- a/src/transformers/models/detr/configuration_detr.py +++ b/src/transformers/models/detr/configuration_detr.py @@ -157,13 +157,15 @@ def __init__( ): backbone_kwargs = kwargs.get("backbone_kwargs", {}) timm_default_kwargs = { - "num_channels": backbone_kwargs.get("num_channels", num_channels), - "features_only": True, - "use_pretrained_backbone": False, + "model_args": { + "in_chans": backbone_kwargs.get("num_channels", num_channels), + "features_only": True, + "pretrained": False, + }, "out_indices": backbone_kwargs.get("out_indices", [1, 2, 3, 4]), } if dilation: - timm_default_kwargs["output_stride"] = backbone_kwargs.get("output_stride", 16) + timm_default_kwargs["model_args"]["output_stride"] = backbone_kwargs.get("output_stride", 16) backbone_config, kwargs = consolidate_backbone_kwargs_to_config( backbone_config=backbone_config, diff --git a/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py 
b/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py index e662f318a69b..a00ca746ce0a 100644 --- a/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py +++ b/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py @@ -187,9 +187,13 @@ def __init__( ): # Init timm backbone with hardcoded values for BC timm_default_kwargs = { + "model_args": { + "img_size": image_size, + "features_only": True, + "pretrained": False, + "always_partition": True, + }, "out_indices": [1, 2, 3], - "img_size": image_size, - "always_partition": True, } backbone_config, kwargs = consolidate_backbone_kwargs_to_config( backbone_config=backbone_config, diff --git a/src/transformers/models/table_transformer/configuration_table_transformer.py b/src/transformers/models/table_transformer/configuration_table_transformer.py index f8c3b2e79320..9c1de7eb78af 100644 --- a/src/transformers/models/table_transformer/configuration_table_transformer.py +++ b/src/transformers/models/table_transformer/configuration_table_transformer.py @@ -158,9 +158,11 @@ def __init__( ): backbone_kwargs = kwargs.get("backbone_kwargs", {}) timm_default_kwargs = { - "num_channels": backbone_kwargs.get("num_channels", num_channels), - "features_only": True, - "use_pretrained_backbone": False, + "model_args": { + "in_chans": backbone_kwargs.get("num_channels", num_channels), + "features_only": True, + "pretrained": False, + }, "out_indices": backbone_kwargs.get("out_indices", [1, 2, 3, 4]), } if dilation: diff --git a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py index 1e2365c64c3a..b0f7da1d2414 100644 --- a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py @@ -76,18 +76,10 @@ def __init__( model_args: dict[str, Any] | None = None, **kwargs, ): - is_backbone_config = kwargs.get("backbone") is not None - self.architecture = kwargs.pop("backbone") if is_backbone_config else architecture + self.architecture = architecture self.initializer_range = initializer_range self.do_pooling = do_pooling self.freeze_batch_norm_2d = freeze_batch_norm_2d - if model_args is None and is_backbone_config: - model_args = { - "features_only": kwargs.pop("features_only", True), - "in_chans": kwargs.pop("num_channels", 3), - "output_stride": kwargs.get("output_stride"), - } - self.model_args = model_args # named "model_args" for BC with timm super().__init__(**kwargs) diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py index a7d77815d909..149debf17e0a 100644 --- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -316,7 +316,10 @@ def forward( pixel_values = pixel_values.to(self.device) if self.features_only: - # TODO: ideally features only should be used with `BackboneModel`, deprecate here! + logger.warning( + f"Using a `features_only` mode in {self.__class__.__name__} is deprecated and will be removed in v5.20.0. " + "Instead please use `TimmWrapperBackboneModel` to obtain feature maps."
+ ) last_hidden_state = self.timm_model.forward(pixel_values) hidden_states = last_hidden_state if output_hidden_states else None pooler_output = None From 1d94e420b56f091c6d941ccdf8f733b43ac364b3 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 25 Feb 2026 17:12:28 +0100 Subject: [PATCH 18/22] docs --- docs/source/en/main_classes/backbones.md | 2 +- .../models/timm_backbone/configuration_timm_backbone.py | 2 ++ .../models/timm_backbone/modeling_timm_backbone.py | 6 ++++++ .../models/timm_wrapper/configuration_timm_wrapper.py | 6 +----- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/docs/source/en/main_classes/backbones.md b/docs/source/en/main_classes/backbones.md index 3a0291bda898..6cc1b5034a54 100644 --- a/docs/source/en/main_classes/backbones.md +++ b/docs/source/en/main_classes/backbones.md @@ -21,7 +21,7 @@ A backbone is a model used for feature extraction for higher level computer visi * [`~backbone_utils.BackboneMixin`] enables initializing a backbone from Transformers or [timm](https://hf.co/docs/timm/index) and includes functions for returning the output features and indices. * [`~backbone_utils.BackboneConfigMixin`] sets the output features and indices of the backbone configuration. -[timm](https://hf.co/docs/timm/index) models are loaded with the [`TimmBackbone`] and [`TimmBackboneConfig`] classes. +[timm](https://hf.co/docs/timm/index) models are loaded with the [`TimmWrapperBackboneModel`] and [`TimmWrapperConfig`] classes. Backbones are supported for the following models: diff --git a/src/transformers/models/timm_backbone/configuration_timm_backbone.py b/src/transformers/models/timm_backbone/configuration_timm_backbone.py index ef15e98ebb28..6f63b4bfcc4c 100644 --- a/src/transformers/models/timm_backbone/configuration_timm_backbone.py +++ b/src/transformers/models/timm_backbone/configuration_timm_backbone.py @@ -30,6 +30,8 @@ class TimmBackboneConfig(TimmWrapperConfig): Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PreTrainedConfig`] for more information. + Note that the config class is deprecated, use `TimmWrapperConfig` instead! + Args: backbone (`str`, *optional*): The timm checkpoint to load. diff --git a/src/transformers/models/timm_backbone/modeling_timm_backbone.py b/src/transformers/models/timm_backbone/modeling_timm_backbone.py index 7e57d64b3514..973d8e7b7b41 100644 --- a/src/transformers/models/timm_backbone/modeling_timm_backbone.py +++ b/src/transformers/models/timm_backbone/modeling_timm_backbone.py @@ -21,6 +21,12 @@ class TimmBackbone(TimmWrapperBackboneModel): + """ + Wrapper class for timm models to be used as backbones. This enables using the timm models interchangeably with the + other models in the library keeping the same API. + Note that the model is deprecated, use `TimmWrapperBackboneModel` instead! + """ + def __init__(self, *args, **kwargs): logger.warning( "`TimmBackbone` is deprecated and will be removed in future versions. 
Use a " diff --git a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py index b0f7da1d2414..82b368101a85 100644 --- a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py @@ -18,11 +18,7 @@ from ...backbone_utils import BackboneConfigMixin from ...configuration_utils import PreTrainedConfig -from ...utils import is_timm_available, logging - - -if is_timm_available(): - pass +from ...utils import logging logger = logging.get_logger(__name__) From 2f367188433cde598a13c0a20055cce5116dfb0c Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 25 Feb 2026 17:19:53 +0100 Subject: [PATCH 19/22] fix copies --- .../models/table_transformer/configuration_table_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/table_transformer/configuration_table_transformer.py b/src/transformers/models/table_transformer/configuration_table_transformer.py index 9c1de7eb78af..2d293f22b5f1 100644 --- a/src/transformers/models/table_transformer/configuration_table_transformer.py +++ b/src/transformers/models/table_transformer/configuration_table_transformer.py @@ -166,7 +166,7 @@ def __init__( "out_indices": backbone_kwargs.get("out_indices", [1, 2, 3, 4]), } if dilation: - timm_default_kwargs["output_stride"] = backbone_kwargs.get("output_stride", 16) + timm_default_kwargs["model_args"]["output_stride"] = backbone_kwargs.get("output_stride", 16) backbone_config, kwargs = consolidate_backbone_kwargs_to_config( backbone_config=backbone_config, From 882b555f7cc2948dc5ca00e585a96d09455b4962 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 25 Feb 2026 21:13:51 +0100 Subject: [PATCH 20/22] fix the rest of tests! 
--- src/transformers/conversion_mapping.py | 1 - src/transformers/core_model_loading.py | 4 +- src/transformers/integrations/peft.py | 3 +- src/transformers/modeling_utils.py | 17 +- .../models/colpali/modeling_colpali.py | 7 - .../models/maskformer/modeling_maskformer.py | 354 ++++++++++-------- .../models/paligemma/modeling_paligemma.py | 6 +- .../test_modeling_timm_wrapper.py | 7 +- 8 files changed, 217 insertions(+), 182 deletions(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index d13b8daeb195..b604d2611418 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -372,7 +372,6 @@ def register_checkpoint_conversion_mapping( VLMS = [ "aria", "ayavision", - "colpali", "emu3", "fuyu", "gotocr2", diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index a428a472ad85..b9a8ebd395ba 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -1257,10 +1257,8 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch # In this case, the model was not created with `from_pretrained` -> let's check if it's in the hardcoded # mappings, and recreate the mapping from there if it is if weight_conversions is None: - from .conversion_mapping import get_model_conversion_mapping - # Do not resave with the legacy renaming, if present - weight_conversions = get_model_conversion_mapping(model, add_legacy=False) + weight_conversions = model.get_weight_conversions_recursively(add_legacy=False) weight_conversions = weight_conversions if len(weight_conversions) > 0 else None # We did not find any operations to perform -> quick escape diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py index ff8734c510ad..01f4785943b9 100644 --- a/src/transformers/integrations/peft.py +++ b/src/transformers/integrations/peft.py @@ -22,7 +22,6 @@ from ..conversion_mapping import ( _MODEL_TO_CONVERSION_PATTERN, get_checkpoint_conversion_mapping, - get_model_conversion_mapping, ) from ..core_model_loading import ( Concatenate, @@ -519,7 +518,7 @@ def load_adapter( **load_config.download_kwargs, ) - weight_conversions = get_model_conversion_mapping(self) + weight_conversions = self.get_weight_conversions_recursively() peft_config = convert_peft_config_for_transformers(peft_config, model=self, conversions=weight_conversions) if hasattr(peft_config, "inference_mode"): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index fbba6696fed5..a460b8238bea 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4067,7 +4067,7 @@ def from_pretrained( dtype_plan = model._get_dtype_plan(dtype) # Obtain the weight conversion mapping for this model if any are registered and appy to all submodels recursively - weight_conversions = cls.get_weight_conversions_recursively(model, key_mapping, hf_quantizer) + weight_conversions = model.get_weight_conversions_recursively(key_mapping, hf_quantizer) if _torch_distributed_available and device_mesh is not None: # add hooks to nn.Modules: no weights model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size) @@ -4265,19 +4265,18 @@ def _finalize_model_loading( return loading_info - @classmethod - def get_weight_conversions_recursively(cls, model, key_mapping, hf_quantizer): + def get_weight_conversions_recursively(self, key_mapping=None, hf_quantizer=None, add_legacy=True): 
conversions = [] - conversions.extend(get_model_conversion_mapping(model, key_mapping, hf_quantizer)) + conversions.extend(get_model_conversion_mapping(self, key_mapping, hf_quantizer)) - for submodule in model.modules(): + for submodule in self.children(): if ( - submodule is not model + submodule is not self and isinstance(submodule, PreTrainedModel) - and submodule.config.__class__ != model.config.__class__ + and submodule.config.__class__ != self.config.__class__ ): - conversions.extend(get_model_conversion_mapping(submodule, key_mapping, hf_quantizer)) - conversions.extend(cls.get_weight_conversions_recursively(submodule, key_mapping, hf_quantizer)) + conversions.extend(get_model_conversion_mapping(submodule, key_mapping, hf_quantizer, add_legacy)) + conversions.extend(submodule.get_weight_conversions_recursively(key_mapping, hf_quantizer, add_legacy)) return conversions def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False): diff --git a/src/transformers/models/colpali/modeling_colpali.py b/src/transformers/models/colpali/modeling_colpali.py index f862bcf943e8..967ec0ba603a 100644 --- a/src/transformers/models/colpali/modeling_colpali.py +++ b/src/transformers/models/colpali/modeling_colpali.py @@ -101,13 +101,6 @@ class ColPaliForRetrievalOutput(ModelOutput): """ ) class ColPaliForRetrieval(ColPaliPreTrainedModel): - _checkpoint_conversion_mapping = { - "vlm.language_model.model": "vlm.model.language_model", - "vlm.vision_tower": "vlm.model.vision_tower", - "vlm.multi_modal_projector": "vlm.model.multi_modal_projector", - "vlm.language_model.lm_head": "vlm.lm_head", - } - def __init__(self, config: ColPaliConfig): super().__init__(config) self.config = config diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 1ad4e0417e98..f53199482af3 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -14,6 +14,7 @@ """PyTorch MaskFormer model.""" import math +from collections.abc import Callable from dataclasses import dataclass from numbers import Number @@ -26,7 +27,7 @@ from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithCrossAttentions -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import compile_compatible_method_lru_cache from ...utils import ( @@ -387,206 +388,262 @@ def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float = return loss / height_and_width -# TODO: use modular - Copied from transformers.models.detr.modeling_detr.DetrAttention -class DetrAttention(nn.Module): +# Copied from transformers.models.detr.modeling_detr.DetrMLP +class DetrMLP(nn.Module): + def __init__(self, config: DetrConfig, hidden_size: int, intermediate_size: int): + super().__init__() + self.fc1 = nn.Linear(hidden_size, intermediate_size) + self.fc2 = nn.Linear(intermediate_size, hidden_size) + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.dropout = config.dropout + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, 
training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float | None = None, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.detr.modeling_detr.DetrSelfAttention +class DetrSelfAttention(nn.Module): """ - Multi-headed attention from 'Attention Is All You Need' paper. + Multi-headed self-attention from 'Attention Is All You Need' paper. - Here, we add position embeddings to the queries and keys (as explained in the DETR paper). + In DETR, position embeddings are added to both queries and keys (but not values) in self-attention. """ def __init__( self, - embed_dim: int, - num_heads: int, + config: DetrConfig, + hidden_size: int, + num_attention_heads: int, dropout: float = 0.0, bias: bool = True, ): super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - if self.head_dim * num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {num_heads})." 
- ) + self.config = config + self.head_dim = hidden_size // num_attention_heads self.scaling = self.head_dim**-0.5 + self.attention_dropout = dropout + self.is_causal = False - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - - def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): - return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def with_pos_embed(self, tensor: torch.Tensor, object_queries: Tensor | None): - return tensor if object_queries is None else tensor + object_queries + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, - object_queries: torch.Tensor | None = None, - key_value_states: torch.Tensor | None = None, - spatial_position_embeddings: torch.Tensor | None = None, - output_attentions: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: - """Input shape: Batch x Time x Channel""" - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - batch_size, target_len, embed_dim = hidden_states.size() - - # add position embeddings to the hidden states before projecting to queries and keys - if object_queries is not None: - hidden_states_original = hidden_states - hidden_states = self.with_pos_embed(hidden_states, object_queries) - - # add key-value position embeddings to the key value states - if spatial_position_embeddings is not None: - key_value_states_original = key_value_states - key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings) - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, batch_size) - value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, batch_size) - value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size) + position_embeddings: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Position embeddings are added to both queries and keys (but not values). 
+ """ + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - proj_shape = (batch_size * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states - source_len = key_states.size(1) + query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) - if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len): - raise ValueError( - f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" - f" {attn_weights.size()}" - ) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) - if attention_mask is not None: - if attention_mask.size() != (batch_size, 1, target_len, source_len): - raise ValueError( - f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is" - f" {attention_mask.size()}" - ) - if attention_mask.dtype == torch.bool: - attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_( - attention_mask, -torch.inf - ) - attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask - attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) - attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) - else: - attn_weights_reshaped = None + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.bmm(attn_probs, value_states) +# Copied from transformers.models.detr.modeling_detr.DetrCrossAttention +class DetrCrossAttention(nn.Module): + """ + Multi-headed cross-attention from 'Attention Is All You Need' paper. - if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + In DETR, queries get their own position embeddings, while keys get encoder position embeddings. + Values don't get any position embeddings. 
+ """ - attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(batch_size, target_len, embed_dim) + def __init__( + self, + config: DetrConfig, + hidden_size: int, + num_attention_heads: int, + dropout: float = 0.0, + bias: bool = True, + ): + super().__init__() + self.config = config + self.head_dim = hidden_size // num_attention_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = dropout + self.is_causal = False - attn_output = self.out_proj(attn_output) + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias) - return attn_output, attn_weights_reshaped + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: torch.Tensor | None = None, + encoder_position_embeddings: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Position embeddings logic: + - Queries get position_embeddings + - Keys get encoder_position_embeddings + - Values don't get any position embeddings + """ + query_input_shape = hidden_states.shape[:-1] + query_hidden_shape = (*query_input_shape, -1, self.head_dim) + + kv_input_shape = key_value_states.shape[:-1] + kv_hidden_shape = (*kv_input_shape, -1, self.head_dim) + + query_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states + key_input = ( + key_value_states + encoder_position_embeddings + if encoder_position_embeddings is not None + else key_value_states + ) + query_states = self.q_proj(query_input).view(query_hidden_shape).transpose(1, 2) + key_states = self.k_proj(key_input).view(kv_hidden_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(kv_hidden_shape).transpose(1, 2) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) -# TODO: use modular - Copied from transformers.models.detr.modeling_detr.DetrDecoderLayer + attn_output = attn_output.reshape(*query_input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +# Copied from transformers.models.detr.modeling_detr.DetrDecoderLayer class DetrDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: DetrConfig): super().__init__() - self.embed_dim = config.d_model + self.hidden_size = config.d_model - self.self_attn = DetrAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, + self.self_attn = DetrSelfAttention( + config=config, + hidden_size=self.hidden_size, + num_attention_heads=config.decoder_attention_heads, dropout=config.attention_dropout, ) self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = DetrAttention( - self.embed_dim, - config.decoder_attention_heads, + self.self_attn_layer_norm = 
nn.LayerNorm(self.hidden_size) + self.encoder_attn = DetrCrossAttention( + config=config, + hidden_size=self.hidden_size, + num_attention_heads=config.decoder_attention_heads, dropout=config.attention_dropout, ) - self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) - self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size) + self.mlp = DetrMLP(config, self.hidden_size, config.decoder_ffn_dim) + self.final_layer_norm = nn.LayerNorm(self.hidden_size) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, - object_queries: torch.Tensor | None = None, - query_position_embeddings: torch.Tensor | None = None, + spatial_position_embeddings: torch.Tensor | None = None, + object_queries_position_embeddings: torch.Tensor | None = None, encoder_hidden_states: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, - output_attentions: bool | None = False, **kwargs: Unpack[TransformersKwargs], - ): + ) -> torch.Tensor: """ Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)` attention_mask (`torch.FloatTensor`): attention mask of size `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative values. - object_queries (`torch.FloatTensor`, *optional*): - object_queries that are added to the hidden states - in the cross-attention layer. - query_position_embeddings (`torch.FloatTensor`, *optional*): - position embeddings that are added to the queries and keys - in the self-attention layer. + spatial_position_embeddings (`torch.FloatTensor`, *optional*): + Spatial position embeddings (2D positional encodings from encoder) that are added to the keys only + in the cross-attention layer (not to values). + object_queries_position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings for the object query slots. In self-attention, these are added to both queries + and keys (not values). In cross-attention, these are added to queries only (not to keys or values). encoder_hidden_states (`torch.FloatTensor`): - cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + cross attention input to the layer of shape `(batch, seq_len, hidden_size)` encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. 
""" residual = hidden_states # Self Attention - hidden_states, self_attn_weights = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, - object_queries=query_position_embeddings, + position_embeddings=object_queries_position_embeddings, attention_mask=attention_mask, - output_attentions=output_attentions, + **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -594,17 +651,16 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states, _ = self.encoder_attn( hidden_states=hidden_states, - object_queries=query_position_embeddings, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - spatial_position_embeddings=object_queries, - output_attentions=output_attentions, + position_embeddings=object_queries_position_embeddings, + encoder_position_embeddings=spatial_position_embeddings, + **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -613,19 +669,11 @@ def forward( # Fully Connected residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - return outputs + return hidden_states class DetrDecoder(PreTrainedModel): diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 15f2071ee2bc..ddb945cf611a 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -429,10 +429,8 @@ def forward( ) class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin): _checkpoint_conversion_mapping = { - "^language_model.model": "model.language_model", - "^vision_tower": "model.vision_tower", - "^multi_modal_projector": "model.multi_modal_projector", - "^language_model.lm_head": "lm_head", + "language_model.model": "model.language_model", + "language_model.lm_head": "lm_head", } _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py index 969e5c45cee3..820e614da045 100644 --- a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py +++ b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py @@ -193,13 +193,14 @@ def test_do_pooling_option(self): self.assertIsNotNone(output.pooler_output) def test_timm_config_labels(self): - # test timm config with no labels + # test timm config with no labels, default labels are created for 100 classes checkpoint = "timm/resnet18.a1_in1k" config = TimmWrapperConfig.from_pretrained(checkpoint) - self.assertIsNone(config.label2id) + self.assertIsInstance(config.label2id, dict) self.assertIsInstance(config.id2label, dict) + self.assertEqual(len(config.id2label), len(config.label2id)) 
self.assertEqual(len(config.id2label), 1000) - self.assertEqual(config.id2label[1], "goldfish, Carassius auratus") + self.assertEqual(config.id2label[1], "LABEL_1") # test timm config with labels in config checkpoint = "timm/eva02_large_patch14_clip_336.merged2b_ft_inat21" From 9794fcfdda4e59c788eefe4fcf7b4f4105928467 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 25 Feb 2026 21:23:04 +0100 Subject: [PATCH 21/22] move paligemma to proper mapping and see --- src/transformers/conversion_mapping.py | 5 ++++- src/transformers/modeling_utils.py | 2 +- src/transformers/models/paligemma/modeling_paligemma.py | 5 ----- tests/test_modeling_common.py | 7 +++---- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index b604d2611418..72c0b22fdcd3 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -67,6 +67,10 @@ def _build_checkpoint_conversion_mapping(): mapping = { + "paligemma": [ + WeightRenaming(source_patterns=r"language_model.model", target_patterns="language_model"), + WeightRenaming(source_patterns=r"language_model.lm_head", target_patterns="lm_head"), + ], "qwen3_5_text": [ WeightRenaming(source_patterns=r"^model.language_model", target_patterns="model"), ], @@ -380,7 +384,6 @@ def register_checkpoint_conversion_mapping( "llava", # all llava prefixed models fall under this check "mistral3", "mllama", - "paligemma", "shieldgemma2", "qwen2vl", "qwen2_5_vl", diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a460b8238bea..33db328827a8 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4267,7 +4267,7 @@ def _finalize_model_loading( def get_weight_conversions_recursively(self, key_mapping=None, hf_quantizer=None, add_legacy=True): conversions = [] - conversions.extend(get_model_conversion_mapping(self, key_mapping, hf_quantizer)) + conversions.extend(get_model_conversion_mapping(self, key_mapping, hf_quantizer, add_legacy)) for submodule in self.children(): if ( diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index ddb945cf611a..d5465e61763e 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -236,7 +236,6 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): """ ) class PaliGemmaModel(PaliGemmaPreTrainedModel): - _checkpoint_conversion_mapping = {"language_model.model": "language_model"} # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch accepts_loss_kwargs = False @@ -428,10 +427,6 @@ def forward( """ ) class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin): - _checkpoint_conversion_mapping = { - "language_model.model": "model.language_model", - "language_model.lm_head": "lm_head", - } _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: PaliGemmaConfig): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index cf876c0ae813..6c78e5494a34 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -43,7 +43,6 @@ logging, set_seed, ) -from transformers.conversion_mapping import get_model_conversion_mapping from transformers.core_model_loading import WeightRenaming, process_target_pattern from transformers.integrations import HfDeepSpeedConfig 
from transformers.integrations.deepspeed import ( @@ -4673,7 +4672,7 @@ def test_reverse_loading_mapping(self, check_keys_were_modified=True): with self.subTest(model_class.__name__): model = model_class(copy.deepcopy(config)) # Skip if no conversions - conversions = get_model_conversion_mapping(model, add_legacy=False) + conversions = model.get_weight_conversions_recursively(add_legacy=False) if len(conversions) == 0: self.skipTest("No conversion found for this model") @@ -4707,7 +4706,7 @@ def test_reverse_loading_mapping(self, check_keys_were_modified=True): self.assertTrue( num_matches > 0, f"`{source_pattern}` in `{conversion}` did not match any of the source keys. " - "This indicates whether that the pattern is not properly written, ot that it could not be reversed correctly", + "This indicates whether that the pattern is not properly written, or that it could not be reversed correctly", ) # If everything is still good at this point, let's test that we perform the same operations both when @@ -4742,7 +4741,7 @@ def test_can_load_from_already_mapped_keys(self): model = model_class(copy.deepcopy(config)) # Skip if no conversions - conversions = get_model_conversion_mapping(model, add_legacy=False) + conversions = model.get_weight_conversions_recursively(add_legacy=False) if len(conversions) == 0: self.skipTest("No conversion found for this model") From f13720c2146016891040e82f4806453c2e697bf0 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 26 Feb 2026 14:34:27 +0100 Subject: [PATCH 22/22] push --- src/transformers/conversion_mapping.py | 27 ++-- src/transformers/core_model_loading.py | 2 + .../models/colqwen2/configuration_colqwen2.py | 2 + .../models/colqwen2/modeling_colqwen2.py | 2 - .../models/colqwen2/modular_colqwen2.py | 2 - .../models/llava/modeling_llava.py | 10 -- .../models/maskformer/modeling_maskformer.py | 132 ++++++------------ .../models/qwen2_vl/modeling_qwen2_vl.py | 5 - .../models/qwen2_vl/test_modeling_qwen2_vl.py | 6 + tests/test_modeling_common.py | 6 +- 10 files changed, 72 insertions(+), 122 deletions(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 72c0b22fdcd3..ac00fe4a97cd 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -55,7 +55,6 @@ "qwen3_omni_moe": "qwen2_moe", "qwen3_omni_moe_thinker": "qwen2_moe", "qwen3_next": "qwen2_moe", - "qwen3_5_moe": "qwen2_moe", "hunyuan_v1_moe": "qwen2_moe", "flex_olmo": "qwen2_moe", "olmoe": "qwen2_moe", @@ -71,18 +70,22 @@ def _build_checkpoint_conversion_mapping(): WeightRenaming(source_patterns=r"language_model.model", target_patterns="language_model"), WeightRenaming(source_patterns=r"language_model.lm_head", target_patterns="lm_head"), ], + "llava": [ + WeightRenaming(source_patterns=r"language_model.model", target_patterns="language_model"), + WeightRenaming(source_patterns=r"language_model.lm_head", target_patterns="lm_head"), + ], + "qwen2_vl": [ + WeightRenaming( + source_patterns=r"(^|\.)model(?!\.(language_model|visual))", target_patterns="model.language_model" + ), + ], "qwen3_5_text": [ WeightRenaming(source_patterns=r"^model.language_model", target_patterns="model"), ], - "t5gemma2": [ - WeightRenaming(r"(? tuple[str, str | None]: # Remove negative lookahead/behind if any. This is ugly but needed for reverse mapping of # Qwen2.5, Sam3, Ernie4.5 VL MoE! 
pattern = re.sub(r"\(\?.+\)", "", pattern) + # Remove the backslash for literal dots + pattern = pattern.replace(r"\.", ".") # Allow capturing groups in patterns, i.e. to add/remove a prefix to all keys (e.g. timm_wrapper, sam3) capturing_group_match = re.search(r"\(.+?\)", pattern) captured_group = None diff --git a/src/transformers/models/colqwen2/configuration_colqwen2.py b/src/transformers/models/colqwen2/configuration_colqwen2.py index b168f502b3de..f1997a8882e5 100644 --- a/src/transformers/models/colqwen2/configuration_colqwen2.py +++ b/src/transformers/models/colqwen2/configuration_colqwen2.py @@ -61,6 +61,7 @@ def __init__( vlm_config=None, embedding_dim: int = 128, initializer_range: float = 0.02, + tie_word_embeddings: bool = True, **kwargs, ): if vlm_config is None: @@ -86,6 +87,7 @@ def __init__( self.vlm_config = vlm_config self.embedding_dim = embedding_dim self.initializer_range = initializer_range + self.tie_word_embeddings = tie_word_embeddings super().__init__(**kwargs) def get_text_config(self, *args, **kwargs) -> PreTrainedConfig: diff --git a/src/transformers/models/colqwen2/modeling_colqwen2.py b/src/transformers/models/colqwen2/modeling_colqwen2.py index 5a1684571231..5c2180bdd236 100644 --- a/src/transformers/models/colqwen2/modeling_colqwen2.py +++ b/src/transformers/models/colqwen2/modeling_colqwen2.py @@ -105,8 +105,6 @@ class ColQwen2ForRetrievalOutput(ModelOutput): """ ) class ColQwen2ForRetrieval(ColQwen2PreTrainedModel): - _checkpoint_conversion_mapping = {} - def __init__(self, config: ColQwen2Config): super().__init__(config) self.config = config diff --git a/src/transformers/models/colqwen2/modular_colqwen2.py b/src/transformers/models/colqwen2/modular_colqwen2.py index b7a8da6364d2..c120859f2f94 100644 --- a/src/transformers/models/colqwen2/modular_colqwen2.py +++ b/src/transformers/models/colqwen2/modular_colqwen2.py @@ -257,8 +257,6 @@ class ColQwen2ForRetrievalOutput(ModelOutput): """ ) class ColQwen2ForRetrieval(ColPaliForRetrieval): - _checkpoint_conversion_mapping = {} - def __init__(self, config: ColQwen2Config): super().__init__(config) del self._tied_weights_keys diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index a52aaa1cda51..0e63987abbbc 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -128,10 +128,6 @@ class LlavaPreTrainedModel(PreTrainedModel): """ ) class LlavaModel(LlavaPreTrainedModel): - _checkpoint_conversion_mapping = { - r"^language_model.model": "language_model", - } - def __init__(self, config: LlavaConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) @@ -282,12 +278,6 @@ def forward( """ ) class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): - _checkpoint_conversion_mapping = { - r"^language_model.model": "model.language_model", - r"^vision_tower": "model.vision_tower", - r"^multi_modal_projector": "model.multi_modal_projector", - r"^language_model.lm_head": "lm_head", - } _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: LlavaConfig): diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index f53199482af3..43250ef96983 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -39,6 +39,8 @@ logging, 
requires_backends, ) +from ...utils.generic import merge_with_config_defaults +from ...utils.output_capturing import capture_outputs from ..auto import AutoBackbone from ..detr import DetrConfig from .configuration_maskformer import MaskFormerConfig @@ -676,51 +678,45 @@ def forward( return hidden_states +# copied from transformers.models.detr.modeling_detr.DetrDecoder with DetrPreTrainedModel->PreTrainedModel class DetrDecoder(PreTrainedModel): - config: DetrConfig - base_model_prefix = "model" - """ - Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`]. - - The decoder updates the query embeddings through multiple self-attention and cross-attention layers. - - Some small tweaks for DETR: - - - object_queries and query_position_embeddings are added to the forward pass. - - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers. + Transformer decoder that refines a set of object queries. It is composed of a stack of [`DetrDecoderLayer`] modules, + which apply self-attention to the queries and cross-attention to the encoder's outputs. Args: - config: DetrConfig + config (`DetrConfig`): Model configuration object. """ + _can_record_outputs = { + "hidden_states": DetrDecoderLayer, + "attentions": DetrSelfAttention, + "cross_attentions": DetrCrossAttention, + } + def __init__(self, config: DetrConfig): super().__init__(config) - self.dropout = config.dropout - self.layerdrop = config.decoder_layerdrop self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)]) # in DETR, the decoder uses layernorm after the last decoder layer output self.layernorm = nn.LayerNorm(config.d_model) - self.gradient_checkpointing = False - + # Initialize weights and apply final processing self.post_init() + @merge_with_config_defaults + @capture_outputs def forward( self, inputs_embeds=None, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, - object_queries=None, - query_position_embeddings=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + spatial_position_embeddings=None, + object_queries_position_embeddings=None, **kwargs: Unpack[TransformersKwargs], - ): + ) -> DetrDecoderOutput: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): @@ -743,99 +739,59 @@ def forward( - 1 for pixels that are real (i.e. **not masked**), - 0 for pixels that are padding (i.e. **masked**). - object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Position embeddings that are added to the queries and keys in each cross-attention layer. - query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): - , *optional*): Position embeddings that are added to the queries and keys in each self-attention layer. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Spatial position embeddings (2D positional encodings from encoder) that are added to the keys in each cross-attention layer. + object_queries_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Position embeddings for the object query slots that are added to the queries and keys in each self-attention layer. """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is not None: hidden_states = inputs_embeds - encoder_attention_mask = create_bidirectional_mask( - config=self.config, - inputs_embeds=inputs_embeds, - attention_mask=encoder_attention_mask, - encoder_hidden_states=encoder_hidden_states, - ) + if attention_mask is not None: + attention_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=hidden_states, + attention_mask=attention_mask, + ) + + # expand encoder attention mask (for cross-attention on encoder outputs) + if encoder_hidden_states is not None and encoder_attention_mask is not None: + encoder_attention_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=hidden_states, + attention_mask=encoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + ) # optional intermediate hidden states intermediate = () if self.config.auxiliary_loss else None # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - if self.training: - dropout_probability = torch.rand([]) - if dropout_probability < self.layerdrop: - continue - - layer_outputs = decoder_layer( + hidden_states = decoder_layer( hidden_states, - None, # attention_mask - object_queries, - query_position_embeddings, + attention_mask, + spatial_position_embeddings, + object_queries_position_embeddings, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, **kwargs, ) - hidden_states = layer_outputs[0] - if self.config.auxiliary_loss: hidden_states = self.layernorm(hidden_states) intermediate += (hidden_states,) - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - # finally, apply layernorm hidden_states = self.layernorm(hidden_states) - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - # stack intermediate decoder activations if self.config.auxiliary_loss: intermediate = torch.stack(intermediate) - if not return_dict: - return tuple( - v - for v in [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions, intermediate] - if v is not None - ) - return DetrDecoderOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attns, - 
cross_attentions=all_cross_attentions,
-            intermediate_hidden_states=intermediate,
-        )
+        return DetrDecoderOutput(last_hidden_state=hidden_states, intermediate_hidden_states=intermediate)


 # refactored from original implementation
@@ -1476,8 +1432,8 @@ def forward(
             attention_mask=None,
             encoder_hidden_states=image_features,
             encoder_attention_mask=None,
-            object_queries=object_queries,
-            query_position_embeddings=queries_embeddings,
+            spatial_position_embeddings=object_queries,
+            object_queries_position_embeddings=queries_embeddings,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index 3a403942fc18..d5cd0cb13c5e 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -972,7 +972,6 @@ def forward(
 @auto_docstring
 class Qwen2VLModel(Qwen2VLPreTrainedModel):
     base_model_prefix = "model"
-    _checkpoint_conversion_mapping = {"^model": "language_model"}
     # Reference: fix gemma3 grad acc #37208
     accepts_loss_kwargs = False
 
@@ -1370,10 +1369,6 @@ def forward(
 
 
 class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
-    _checkpoint_conversion_mapping = {
-        "^visual": "model.visual",
-        r"^model(?!\.(language_model|visual))": "model.language_model",
-    }
     _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
 
     def __init__(self, config):
diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
index dbc476596a36..9ede499661e1 100644
--- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
+++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
@@ -349,6 +349,12 @@ def attention_mask_padding_matches_padding_free_with_position_ids(
         tol = torch.finfo(torch.bfloat16).eps
         torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
 
+    @unittest.skip(
+        reason="Conversion happens only with a checkpoint pre-saved on the Hub. A model initialized from its config doesn't rename any keys"
+    )
+    def test_reverse_loading_mapping(self):
+        pass
+
     @unittest.skip(reason="Feedforward chunking is not yet supported")
     def test_feed_forward_chunking(self):
         pass
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 6c78e5494a34..f1f8135ab452 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4674,7 +4674,8 @@ def test_reverse_loading_mapping(self, check_keys_were_modified=True):
         # Skip if no conversions
         conversions = model.get_weight_conversions_recursively(add_legacy=False)
         if len(conversions) == 0:
-            self.skipTest("No conversion found for this model")
+            # This model class has no conversion mapping; continue so the other classes are still tested
+            continue
 
         # Find the model keys, so the targets according to the conversions
         model_keys = list(model.state_dict().keys())
@@ -4743,7 +4744,8 @@ def test_can_load_from_already_mapped_keys(self):
         # Skip if no conversions
         conversions = model.get_weight_conversions_recursively(add_legacy=False)
         if len(conversions) == 0:
-            self.skipTest("No conversion mapping for this model class; continue so the other classes are still tested")
+            # This model class has no conversion mapping; continue so the other classes are still tested
+            continue
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             # Serialize without reverting the mapping
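A minimal, self-contained sketch of the loop shape the two test hunks above switch to: when a class reports no weight conversions via get_weight_conversions_recursively(add_legacy=False), the loop moves on to the next class instead of skipping the whole test. Only that method name is taken from the patch; the stub class, helper, and sample data below are hypothetical stand-ins.

# Illustrative sketch only; not part of the patch.
class _StubModel:
    def __init__(self, name, conversions):
        self.name = name
        self._conversions = conversions

    def get_weight_conversions_recursively(self, add_legacy=False):
        # Assumed to return a (possibly empty) collection of key-renaming rules.
        return self._conversions


def run_reverse_loading_check(models):
    checked = []
    for model in models:
        conversions = model.get_weight_conversions_recursively(add_legacy=False)
        if len(conversions) == 0:
            # This class has no conversion mapping; keep checking the other
            # classes instead of aborting the whole test.
            continue
        # The real test would serialize, reload, and compare state_dict keys here.
        checked.append(model.name)
    return checked


print(run_reverse_loading_check([_StubModel("A", []), _StubModel("B", [("^old", "new")])]))  # ['B']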