diff --git a/docs/source/en/main_classes/backbones.md b/docs/source/en/main_classes/backbones.md index 3a0291bda898..6cc1b5034a54 100644 --- a/docs/source/en/main_classes/backbones.md +++ b/docs/source/en/main_classes/backbones.md @@ -21,7 +21,7 @@ A backbone is a model used for feature extraction for higher level computer visi * [`~backbone_utils.BackboneMixin`] enables initializing a backbone from Transformers or [timm](https://hf.co/docs/timm/index) and includes functions for returning the output features and indices. * [`~backbone_utils.BackboneConfigMixin`] sets the output features and indices of the backbone configuration. -[timm](https://hf.co/docs/timm/index) models are loaded with the [`TimmBackbone`] and [`TimmBackboneConfig`] classes. +[timm](https://hf.co/docs/timm/index) models are loaded with the [`TimmWrapperBackboneModel`] and [`TimmWrapperConfig`] classes. Backbones are supported for the following models: diff --git a/docs/source/en/model_doc/timm_wrapper.md b/docs/source/en/model_doc/timm_wrapper.md index c6e508b9ffe1..81cf6e04ba81 100644 --- a/docs/source/en/model_doc/timm_wrapper.md +++ b/docs/source/en/model_doc/timm_wrapper.md @@ -71,6 +71,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h [[autodoc]] TimmWrapperImageProcessor - preprocess +## TimmWrapperBackboneModel + +[[autodoc]] TimmWrapperBackboneModel + - forward + ## TimmWrapperModel [[autodoc]] TimmWrapperModel diff --git a/src/transformers/backbone_utils.py b/src/transformers/backbone_utils.py index 30a3a140b783..0cc9041dc280 100644 --- a/src/transformers/backbone_utils.py +++ b/src/transformers/backbone_utils.py @@ -297,7 +297,7 @@ def consolidate_backbone_kwargs_to_config( and backbone_config is None and not backbone_kwargs ): - backbone_config = CONFIG_MAPPING["timm_backbone"](backbone=backbone, **timm_default_kwargs) + backbone_config = CONFIG_MAPPING["timm_wrapper"](backbone, **timm_default_kwargs) elif backbone is not None and backbone_config is None: if repo_exists(backbone): config_dict, _ = PreTrainedConfig.get_config_dict(backbone) @@ -305,7 +305,14 @@ def consolidate_backbone_kwargs_to_config( config_dict.update(backbone_kwargs) backbone_config = config_class(**config_dict) else: - backbone_config = CONFIG_MAPPING["timm_backbone"](backbone=backbone, **backbone_kwargs) + # Move timm-args inside `model_args` to support loading from TimmBackboneConfig + if "model_args" not in backbone_kwargs: + backbone_kwargs["model_args"] = { + "in_chans": backbone_kwargs.pop("num_channels", 3), + "features_only": backbone_kwargs.pop("features_only", True), + "output_stride": backbone_kwargs.pop("output_stride", None), + } + backbone_config = CONFIG_MAPPING["timm_wrapper"](backbone, **backbone_kwargs) elif backbone_config is None and default_config_type is not None: logger.info( f"`backbone_config` is `None`. Initializing the config with the default `{default_config_type}` vision config." 
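Note on the hunk above: the timm backbone path now builds a `TimmWrapperConfig` instead of a `TimmBackboneConfig`, and the timm-specific kwargs (`num_channels`, `features_only`, `output_stride`) are nested under `model_args`. Below is a minimal sketch of that repackaging, assuming the wrapper config takes the architecture name plus a `model_args` dict as in the calls above; the helper name is hypothetical, not a transformers API.

```python
# Sketch of the legacy-kwargs migration performed above; `migrate_timm_backbone_kwargs`
# is a hypothetical helper used only for illustration.
def migrate_timm_backbone_kwargs(backbone: str, **backbone_kwargs) -> dict:
    if "model_args" not in backbone_kwargs:
        backbone_kwargs["model_args"] = {
            "in_chans": backbone_kwargs.pop("num_channels", 3),
            "features_only": backbone_kwargs.pop("features_only", True),
            "output_stride": backbone_kwargs.pop("output_stride", None),
        }
    return {"architecture": backbone, **backbone_kwargs}


print(migrate_timm_backbone_kwargs("resnet50", num_channels=3, out_indices=(1, 2, 3, 4)))
# {'architecture': 'resnet50', 'out_indices': (1, 2, 3, 4),
#  'model_args': {'in_chans': 3, 'features_only': True, 'output_stride': None}}
```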
@@ -314,28 +321,15 @@ def consolidate_backbone_kwargs_to_config( backbone_config = CONFIG_MAPPING[default_config_type](**default_config_kwargs) elif isinstance(backbone_config, dict): backbone_model_type = backbone_config.get("model_type") + if backbone_model_type == "timm_backbone": + backbone_model_type = "timm_wrapper" + # Move timm-args inside `model_args` + backbone_config["model_args"] = { + "in_chans": backbone_config.pop("num_channels", 3), + "features_only": backbone_config.pop("features_only", True), + "output_stride": backbone_config.pop("output_stride", None), + } config_class = CONFIG_MAPPING[backbone_model_type] backbone_config = config_class.from_dict(backbone_config) return backbone_config, kwargs - - -def load_backbone(config): - """ - Loads the backbone model from a config object. - - If the config is from the backbone model itself, then we return a backbone model with randomly initialized - weights. - - If the config is from the parent model of the backbone model itself, then we load the pretrained backbone weights - if specified. - """ - from transformers import AutoBackbone - - backbone_config = getattr(config, "backbone_config", None) - - if backbone_config is None: - backbone = AutoBackbone.from_config(config=config) - else: - backbone = AutoBackbone.from_config(config=backbone_config) - return backbone diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 03c056c9124e..ac00fe4a97cd 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -55,7 +55,6 @@ "qwen3_omni_moe": "qwen2_moe", "qwen3_omni_moe_thinker": "qwen2_moe", "qwen3_next": "qwen2_moe", - "qwen3_5_moe": "qwen2_moe", "hunyuan_v1_moe": "qwen2_moe", "flex_olmo": "qwen2_moe", "olmoe": "qwen2_moe", @@ -67,18 +66,26 @@ def _build_checkpoint_conversion_mapping(): mapping = { + "paligemma": [ + WeightRenaming(source_patterns=r"language_model.model", target_patterns="language_model"), + WeightRenaming(source_patterns=r"language_model.lm_head", target_patterns="lm_head"), + ], + "llava": [ + WeightRenaming(source_patterns=r"language_model.model", target_patterns="language_model"), + WeightRenaming(source_patterns=r"language_model.lm_head", target_patterns="lm_head"), + ], + "qwen2_vl": [ + WeightRenaming( + source_patterns=r"(^|\.)model(?!\.(language_model|visual))", target_patterns="model.language_model" + ), + ], "qwen3_5_text": [ WeightRenaming(source_patterns=r"^model.language_model", target_patterns="model"), ], - "t5gemma2": [ - WeightRenaming(r"(? tuple[str, str | None]: # Remove negative lookahead/behind if any. This is ugly but needed for reverse mapping of # Qwen2.5, Sam3, Ernie4.5 VL MoE! pattern = re.sub(r"\(\?.+\)", "", pattern) + # Remove the backslash for literal dots + pattern = pattern.replace(r"\.", ".") # Allow capturing groups in patterns, i.e. to add/remove a prefix to all keys (e.g. 
timm_wrapper, sam3) capturing_group_match = re.search(r"\(.+?\)", pattern) captured_group = None @@ -1257,10 +1259,8 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch # In this case, the model was not created with `from_pretrained` -> let's check if it's in the hardcoded # mappings, and recreate the mapping from there if it is if weight_conversions is None: - from .conversion_mapping import get_model_conversion_mapping - # Do not resave with the legacy renaming, if present - weight_conversions = get_model_conversion_mapping(model, add_legacy=False) + weight_conversions = model.get_weight_conversions_recursively(add_legacy=False) weight_conversions = weight_conversions if len(weight_conversions) > 0 else None # We did not find any operations to perform -> quick escape diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py index ff8734c510ad..01f4785943b9 100644 --- a/src/transformers/integrations/peft.py +++ b/src/transformers/integrations/peft.py @@ -22,7 +22,6 @@ from ..conversion_mapping import ( _MODEL_TO_CONVERSION_PATTERN, get_checkpoint_conversion_mapping, - get_model_conversion_mapping, ) from ..core_model_loading import ( Concatenate, @@ -519,7 +518,7 @@ def load_adapter( **load_config.download_kwargs, ) - weight_conversions = get_model_conversion_mapping(self) + weight_conversions = self.get_weight_conversions_recursively() peft_config = convert_peft_config_for_transformers(peft_config, model=self, conversions=weight_conversions) if hasattr(peft_config, "inference_mode"): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f91fa71814d8..33db328827a8 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4066,8 +4066,8 @@ def from_pretrained( # instantiated model, as the flags can be modified by instances sometimes) dtype_plan = model._get_dtype_plan(dtype) - # Obtain the weight conversion mapping for this model if any are registered - weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer) + # Obtain the weight conversion mapping for this model if any are registered and appy to all submodels recursively + weight_conversions = model.get_weight_conversions_recursively(key_mapping, hf_quantizer) if _torch_distributed_available and device_mesh is not None: # add hooks to nn.Modules: no weights model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size) @@ -4265,6 +4265,20 @@ def _finalize_model_loading( return loading_info + def get_weight_conversions_recursively(self, key_mapping=None, hf_quantizer=None, add_legacy=True): + conversions = [] + conversions.extend(get_model_conversion_mapping(self, key_mapping, hf_quantizer, add_legacy)) + + for submodule in self.children(): + if ( + submodule is not self + and isinstance(submodule, PreTrainedModel) + and submodule.config.__class__ != self.config.__class__ + ): + conversions.extend(get_model_conversion_mapping(submodule, key_mapping, hf_quantizer, add_legacy)) + conversions.extend(submodule.get_weight_conversions_recursively(key_mapping, hf_quantizer, add_legacy)) + return conversions + def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False): module_keys = {".".join(key.split(".")[:-1]) for key in names} diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index af9e0e569349..47f56954a6b5 100644 --- a/src/transformers/models/auto/auto_factory.py +++ 
b/src/transformers/models/auto/auto_factory.py @@ -21,8 +21,6 @@ from collections.abc import Iterator from typing import Any, TypeVar -from huggingface_hub import repo_exists - from ...configuration_utils import PreTrainedConfig from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code from ...utils import ( @@ -247,8 +245,38 @@ def _prepare_config_for_auto_class(cls, config: PreTrainedConfig) -> PreTrainedC """Additional autoclass-specific config post-loading manipulation. May be overridden in subclasses.""" return config + @classmethod + def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + requires_backends(cls, ["vision", "timm"]) + from ...models.timm_wrapper import TimmWrapperConfig + + if kwargs.get("output_loading_info", False): + raise ValueError("Cannot specify `output_loading_info=True` when loading from timm") + + # Users can't pass `config` and `kwargs`, choose only one! + config = kwargs.pop("config", None) + if config is None: + config = TimmWrapperConfig( + architecture=pretrained_model_name_or_path, + do_pooling=False, + out_indices=kwargs.pop("out_indices", (-1,)), + model_args={ + "in_chans": kwargs.pop("num_channels", 3), + "features_only": kwargs.pop("features_only", True), + }, + ) + + # Always load a pretrained model when `from_pretrained` is called + kwargs.pop("use_pretrained_backbone", None) + return cls.from_config(config, pretrained=True, **kwargs) + @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], *model_args, **kwargs): + # Early exit for `timm` models, they aren't hosted on the hub usually + use_timm_backbone = kwargs.pop("use_timm_backbone", None) + if use_timm_backbone: + return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + config = kwargs.pop("config", None) trust_remote_code = kwargs.get("trust_remote_code") kwargs["_from_auto"] = True @@ -399,45 +427,6 @@ def register(cls, config_class, model_class, exist_ok=False) -> None: cls._model_mapping.register(config_class, model_class, exist_ok=exist_ok) -class _BaseAutoBackboneClass(_BaseAutoModelClass): - # Base class for auto backbone models. 
- _model_mapping = None - - @classmethod - def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - requires_backends(cls, ["vision", "timm"]) - from ...models.timm_backbone import TimmBackboneConfig - - config = kwargs.pop("config", TimmBackboneConfig()) - - if kwargs.get("out_features") is not None: - raise ValueError("Cannot specify `out_features` for timm backbones") - - if kwargs.get("output_loading_info", False): - raise ValueError("Cannot specify `output_loading_info=True` when loading from timm") - - num_channels = kwargs.pop("num_channels", config.num_channels) - features_only = kwargs.pop("features_only", config.features_only) - out_indices = kwargs.pop("out_indices", config.out_indices) - config = TimmBackboneConfig( - backbone=pretrained_model_name_or_path, - num_channels=num_channels, - features_only=features_only, - out_indices=out_indices, - ) - # Always load a pretrained model when `from_pretrained` is called - kwargs.pop("use_pretrained_backbone", None) - return super().from_config(config, pretrained=True, **kwargs) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - kwargs.pop("use_timm_backbone", None) - if not repo_exists(pretrained_model_name_or_path): - return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - - return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - - def insert_head_doc(docstring, head_doc: str = ""): if len(head_doc) > 0: return docstring.replace( diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 7e5bdf65ea3c..40b0bbeebc20 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -445,7 +445,7 @@ ("time_series_transformer", "TimeSeriesTransformerConfig"), ("timesfm", "TimesFmConfig"), ("timesformer", "TimesformerConfig"), - ("timm_backbone", "TimmBackboneConfig"), + ("timm_backbone", "TimmBackboneConfig"), # for BC ("timm_wrapper", "TimmWrapperConfig"), ("trocr", "TrOCRConfig"), ("tvp", "TvpConfig"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index f0cb6b5b3fe7..3d50cbaea04a 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -19,7 +19,6 @@ from ...utils import logging from .auto_factory import ( - _BaseAutoBackboneClass, _BaseAutoModelClass, _LazyAutoMapping, auto_class_update, @@ -427,7 +426,6 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("time_series_transformer", "TimeSeriesTransformerModel"), ("timesfm", "TimesFmModel"), ("timesformer", "TimesformerModel"), - ("timm_backbone", "TimmBackbone"), ("timm_wrapper", "TimmWrapperModel"), ("tvp", "TvpModel"), ("udop", "UdopModel"), @@ -778,7 +776,6 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("swinv2", "Swinv2Model"), ("table-transformer", "TableTransformerModel"), ("timesformer", "TimesformerModel"), - ("timm_backbone", "TimmBackbone"), ("timm_wrapper", "TimmWrapperModel"), ("videomae", "VideoMAEModel"), ("vit", "ViTModel"), @@ -1647,7 +1644,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("swin", "SwinBackbone"), ("swinv2", "Swinv2Backbone"), ("textnet", "TextNetBackbone"), - ("timm_backbone", "TimmBackbone"), + ("timm_backbone", "TimmWrapperBackboneModel"), # for BC + ("timm_wrapper", 
"TimmWrapperBackboneModel"), ("vitdet", "VitDetBackbone"), ("vitpose_backbone", "VitPoseBackbone"), ] @@ -2161,7 +2159,7 @@ class AutoModelForTextToWaveform(_BaseAutoModelClass): _model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING -class AutoBackbone(_BaseAutoBackboneClass): +class AutoBackbone(_BaseAutoModelClass): _model_mapping = MODEL_FOR_BACKBONE_MAPPING diff --git a/src/transformers/models/colpali/modeling_colpali.py b/src/transformers/models/colpali/modeling_colpali.py index f862bcf943e8..967ec0ba603a 100644 --- a/src/transformers/models/colpali/modeling_colpali.py +++ b/src/transformers/models/colpali/modeling_colpali.py @@ -101,13 +101,6 @@ class ColPaliForRetrievalOutput(ModelOutput): """ ) class ColPaliForRetrieval(ColPaliPreTrainedModel): - _checkpoint_conversion_mapping = { - "vlm.language_model.model": "vlm.model.language_model", - "vlm.vision_tower": "vlm.model.vision_tower", - "vlm.multi_modal_projector": "vlm.model.multi_modal_projector", - "vlm.language_model.lm_head": "vlm.lm_head", - } - def __init__(self, config: ColPaliConfig): super().__init__(config) self.config = config diff --git a/src/transformers/models/colqwen2/configuration_colqwen2.py b/src/transformers/models/colqwen2/configuration_colqwen2.py index b168f502b3de..f1997a8882e5 100644 --- a/src/transformers/models/colqwen2/configuration_colqwen2.py +++ b/src/transformers/models/colqwen2/configuration_colqwen2.py @@ -61,6 +61,7 @@ def __init__( vlm_config=None, embedding_dim: int = 128, initializer_range: float = 0.02, + tie_word_embeddings: bool = True, **kwargs, ): if vlm_config is None: @@ -86,6 +87,7 @@ def __init__( self.vlm_config = vlm_config self.embedding_dim = embedding_dim self.initializer_range = initializer_range + self.tie_word_embeddings = tie_word_embeddings super().__init__(**kwargs) def get_text_config(self, *args, **kwargs) -> PreTrainedConfig: diff --git a/src/transformers/models/colqwen2/modeling_colqwen2.py b/src/transformers/models/colqwen2/modeling_colqwen2.py index 5a1684571231..5c2180bdd236 100644 --- a/src/transformers/models/colqwen2/modeling_colqwen2.py +++ b/src/transformers/models/colqwen2/modeling_colqwen2.py @@ -105,8 +105,6 @@ class ColQwen2ForRetrievalOutput(ModelOutput): """ ) class ColQwen2ForRetrieval(ColQwen2PreTrainedModel): - _checkpoint_conversion_mapping = {} - def __init__(self, config: ColQwen2Config): super().__init__(config) self.config = config diff --git a/src/transformers/models/colqwen2/modular_colqwen2.py b/src/transformers/models/colqwen2/modular_colqwen2.py index b7a8da6364d2..c120859f2f94 100644 --- a/src/transformers/models/colqwen2/modular_colqwen2.py +++ b/src/transformers/models/colqwen2/modular_colqwen2.py @@ -257,8 +257,6 @@ class ColQwen2ForRetrievalOutput(ModelOutput): """ ) class ColQwen2ForRetrieval(ColPaliForRetrieval): - _checkpoint_conversion_mapping = {} - def __init__(self, config: ColQwen2Config): super().__init__(config) del self._tied_weights_keys diff --git a/src/transformers/models/conditional_detr/configuration_conditional_detr.py b/src/transformers/models/conditional_detr/configuration_conditional_detr.py index a8ad94920057..de55f837d578 100644 --- a/src/transformers/models/conditional_detr/configuration_conditional_detr.py +++ b/src/transformers/models/conditional_detr/configuration_conditional_detr.py @@ -161,13 +161,15 @@ def __init__( # Init timm backbone with hardcoded values for BC backbone_kwargs = kwargs.get("backbone_kwargs", {}) timm_default_kwargs = { - "num_channels": backbone_kwargs.get("num_channels", num_channels), - 
"features_only": True, - "use_pretrained_backbone": False, + "model_args": { + "in_chans": backbone_kwargs.get("num_channels", num_channels), + "features_only": True, + "pretrained": False, + }, "out_indices": backbone_kwargs.get("out_indices", [1, 2, 3, 4]), } if dilation: - timm_default_kwargs["output_stride"] = backbone_kwargs.get("output_stride", 16) + timm_default_kwargs["model_args"]["output_stride"] = backbone_kwargs.get("output_stride", 16) backbone_config, kwargs = consolidate_backbone_kwargs_to_config( backbone_config=backbone_config, diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 2de83de19c12..fe6a66c79b8c 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -26,7 +26,6 @@ from ... import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput @@ -36,6 +35,7 @@ from ...utils import ModelOutput, TransformersKwargs, auto_docstring from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs +from ..auto import AutoBackbone from .configuration_conditional_detr import ConditionalDetrConfig @@ -258,7 +258,7 @@ def __init__(self, config): self.config = config - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) self.intermediate_channel_sizes = backbone.channels # replace batch norm by frozen batch norm @@ -266,10 +266,10 @@ def __init__(self, config): replace_batch_norm(backbone) # We used to load with timm library directly instead of the AutoBackbone API - # so we need to unwrap the `backbone._backbone` module to load weights without mismatch + # so we need to unwrap the `backbone.timm_model` module to load weights without mismatch is_timm_model = False - if hasattr(backbone, "_backbone"): - backbone = backbone._backbone + if hasattr(backbone, "timm_model"): + backbone = backbone.timm_model is_timm_model = True self.model = backbone diff --git a/src/transformers/models/d_fine/modeling_d_fine.py b/src/transformers/models/d_fine/modeling_d_fine.py index 1c758f8b1dcd..4614cab6f237 100644 --- a/src/transformers/models/d_fine/modeling_d_fine.py +++ b/src/transformers/models/d_fine/modeling_d_fine.py @@ -28,7 +28,6 @@ from ... 
import initialization as init from ...activations import ACT2CLS -from ...backbone_utils import load_backbone from ...image_transforms import center_to_corners_format, corners_to_center_format from ...modeling_outputs import BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -37,6 +36,7 @@ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check, torch_int from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ..auto import AutoBackbone from .configuration_d_fine import DFineConfig @@ -1351,7 +1351,7 @@ class DFineConvEncoder(nn.Module): def __init__(self, config): super().__init__() - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) if config.freeze_backbone_batch_norms: # replace batch norm by frozen batch norm diff --git a/src/transformers/models/dab_detr/configuration_dab_detr.py b/src/transformers/models/dab_detr/configuration_dab_detr.py index f4e3d2e062ad..abc87297b063 100644 --- a/src/transformers/models/dab_detr/configuration_dab_detr.py +++ b/src/transformers/models/dab_detr/configuration_dab_detr.py @@ -175,13 +175,15 @@ def __init__( # Init timm backbone with hardcoded values for BC timm_default_kwargs = { - "num_channels": 3, - "features_only": True, - "use_pretrained_backbone": False, + "model_args": { + "in_chans": 3, + "features_only": True, + "pretrained": False, + }, "out_indices": [1, 2, 3, 4], } if dilation: - timm_default_kwargs["output_stride"] = 16 + timm_default_kwargs["model_args"]["output_stride"] = 16 backbone_config, kwargs = consolidate_backbone_kwargs_to_config( backbone_config=backbone_config, diff --git a/src/transformers/models/dab_detr/modeling_dab_detr.py b/src/transformers/models/dab_detr/modeling_dab_detr.py index 421bdbce6b89..a721563882db 100644 --- a/src/transformers/models/dab_detr/modeling_dab_detr.py +++ b/src/transformers/models/dab_detr/modeling_dab_detr.py @@ -21,7 +21,6 @@ from ... 
import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput @@ -31,6 +30,7 @@ auto_docstring, logging, ) +from ..auto import AutoBackbone from .configuration_dab_detr import DabDetrConfig @@ -211,7 +211,7 @@ def __init__(self, config: DabDetrConfig): super().__init__() self.config = config - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) # replace batch norm by frozen batch norm with torch.no_grad(): diff --git a/src/transformers/models/deformable_detr/configuration_deformable_detr.py b/src/transformers/models/deformable_detr/configuration_deformable_detr.py index ad7f56eb763a..78f57d132cb5 100644 --- a/src/transformers/models/deformable_detr/configuration_deformable_detr.py +++ b/src/transformers/models/deformable_detr/configuration_deformable_detr.py @@ -185,13 +185,15 @@ def __init__( ): # Init timm backbone with hardcoded values for BC timm_default_kwargs = { - "num_channels": 3, - "features_only": True, - "use_pretrained_backbone": False, + "model_args": { + "in_chans": 3, + "features_only": True, + "pretrained": False, + }, "out_indices": [2, 3, 4] if num_feature_levels > 1 else [4], } if dilation: - timm_default_kwargs["output_stride"] = 16 + timm_default_kwargs["model_args"]["output_stride"] = 16 backbone_config, kwargs = consolidate_backbone_kwargs_to_config( backbone_config=backbone_config, diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 3ee685a887c1..e269c3c708b2 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -29,7 +29,6 @@ from ... 
import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone from ...integrations import use_kernel_forward_from_hub from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions @@ -39,6 +38,7 @@ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs +from ..auto import AutoBackbone from .configuration_deformable_detr import DeformableDetrConfig @@ -298,7 +298,7 @@ def __init__(self, config): self.config = config - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) self.intermediate_channel_sizes = backbone.channels # replace batch norm by frozen batch norm @@ -306,10 +306,10 @@ def __init__(self, config): replace_batch_norm(backbone) # We used to load with timm library directly instead of the AutoBackbone API - # so we need to unwrap the `backbone._backbone` module to load weights without mismatch + # so we need to unwrap the `backbone.timm_model` module to load weights without mismatch is_timm_model = False - if hasattr(backbone, "_backbone"): - backbone = backbone._backbone + if hasattr(backbone, "timm_model"): + backbone = backbone.timm_model is_timm_model = True self.model = backbone diff --git a/src/transformers/models/deformable_detr/modular_deformable_detr.py b/src/transformers/models/deformable_detr/modular_deformable_detr.py index dfbc0783fb0a..118ea209ad54 100644 --- a/src/transformers/models/deformable_detr/modular_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modular_deformable_detr.py @@ -21,7 +21,6 @@ from torch import Tensor from ... 
import initialization as init -from ...backbone_utils import load_backbone from ...image_transforms import center_to_corners_format from ...integrations import use_kernel_forward_from_hub from ...modeling_outputs import BaseModelOutput @@ -37,6 +36,7 @@ ) from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs +from ..auto import AutoBackbone from ..detr.image_processing_detr_fast import DetrImageProcessorFast from ..detr.modeling_detr import ( DetrConvEncoder, @@ -292,7 +292,7 @@ def __init__(self, config): self.config = config - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) self.intermediate_channel_sizes = backbone.channels # replace batch norm by frozen batch norm @@ -300,10 +300,10 @@ def __init__(self, config): replace_batch_norm(backbone) # We used to load with timm library directly instead of the AutoBackbone API - # so we need to unwrap the `backbone._backbone` module to load weights without mismatch + # so we need to unwrap the `backbone.timm_model` module to load weights without mismatch is_timm_model = False - if hasattr(backbone, "_backbone"): - backbone = backbone._backbone + if hasattr(backbone, "timm_model"): + backbone = backbone.timm_model is_timm_model = True self.model = backbone diff --git a/src/transformers/models/depth_anything/modeling_depth_anything.py b/src/transformers/models/depth_anything/modeling_depth_anything.py index 16e1e3c0319c..7a53e57b95b6 100644 --- a/src/transformers/models/depth_anything/modeling_depth_anything.py +++ b/src/transformers/models/depth_anything/modeling_depth_anything.py @@ -16,10 +16,10 @@ import torch from torch import nn -from ...backbone_utils import load_backbone from ...modeling_outputs import DepthEstimatorOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging +from ..auto import AutoBackbone from .configuration_depth_anything import DepthAnythingConfig @@ -319,7 +319,7 @@ class DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel): def __init__(self, config): super().__init__(config) - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config=config.backbone_config) self.neck = DepthAnythingNeck(config) self.head = DepthAnythingDepthEstimationHead(config) diff --git a/src/transformers/models/detr/configuration_detr.py b/src/transformers/models/detr/configuration_detr.py index 87835ba098a2..4fc6728bcd90 100644 --- a/src/transformers/models/detr/configuration_detr.py +++ b/src/transformers/models/detr/configuration_detr.py @@ -157,13 +157,15 @@ def __init__( ): backbone_kwargs = kwargs.get("backbone_kwargs", {}) timm_default_kwargs = { - "num_channels": backbone_kwargs.get("num_channels", num_channels), - "features_only": True, - "use_pretrained_backbone": False, + "model_args": { + "in_chans": backbone_kwargs.get("num_channels", num_channels), + "features_only": True, + "pretrained": False, + }, "out_indices": backbone_kwargs.get("out_indices", [1, 2, 3, 4]), } if dilation: - timm_default_kwargs["output_stride"] = backbone_kwargs.get("output_stride", 16) + timm_default_kwargs["model_args"]["output_stride"] = backbone_kwargs.get("output_stride", 16) backbone_config, kwargs = consolidate_backbone_kwargs_to_config( backbone_config=backbone_config, diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 4906b3510f44..d45427ba9c8a 100644 --- 
a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -22,7 +22,6 @@ from ... import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -41,6 +40,7 @@ ) from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ..auto import AutoBackbone from .configuration_detr import DetrConfig @@ -258,7 +258,7 @@ def __init__(self, config): self.config = config - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) self.intermediate_channel_sizes = backbone.channels # replace batch norm by frozen batch norm @@ -266,10 +266,10 @@ def __init__(self, config): replace_batch_norm(backbone) # We used to load with timm library directly instead of the AutoBackbone API - # so we need to unwrap the `backbone._backbone` module to load weights without mismatch + # so we need to unwrap the `backbone.timm_model` module to load weights without mismatch is_timm_model = False - if hasattr(backbone, "_backbone"): - backbone = backbone._backbone + if hasattr(backbone, "timm_model"): + backbone = backbone.timm_model is_timm_model = True self.model = backbone diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index ac8b255bfecd..e9f96d148d76 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -28,7 +28,6 @@ from ... import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -36,6 +35,7 @@ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging, torch_int from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ..auto import AutoBackbone from .configuration_dpt import DPTConfig @@ -102,7 +102,7 @@ def __init__(self, config: DPTConfig, feature_size: tuple[int, int] | None = Non patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config=config.backbone_config) feature_dim = self.backbone.channels[-1] if len(self.backbone.channels) != 3: raise ValueError(f"Expected backbone to have 3 output features, got {len(self.backbone.channels)}") @@ -925,7 +925,7 @@ def __init__(self, config): self.backbone = None if config.is_hybrid is False and config.backbone_config is not None: - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config=config.backbone_config) else: self.dpt = DPTModel(config, add_pooling_layer=False) diff --git a/src/transformers/models/gemma3n/configuration_gemma3n.py b/src/transformers/models/gemma3n/configuration_gemma3n.py index 37e076f5861e..c8193622d5ad 100644 --- a/src/transformers/models/gemma3n/configuration_gemma3n.py +++ b/src/transformers/models/gemma3n/configuration_gemma3n.py @@ -23,11 +23,7 @@ from 
...configuration_utils import PreTrainedConfig, layer_type_validation from ...modeling_rope_utils import RopeParameters -from ...utils import is_timm_available, logging, requires_backends - - -if is_timm_available(): - from timm.data import ImageNetInfo, infer_imagenet_subset +from ...utils import logging logger = logging.get_logger(__name__) @@ -502,57 +498,6 @@ def __init__( self.model_args = model_args # named "model_args" for BC with timm super().__init__(**kwargs) - @classmethod - def from_dict(cls, config_dict: dict[str, Any], **kwargs): - # Create a copy to avoid mutating the original dict - config_dict = config_dict.copy() - - label_names = config_dict.get("label_names") - is_custom_model = "num_labels" in kwargs or "id2label" in kwargs - - # if no labels added to config, use imagenet labeller in timm - if label_names is None and not is_custom_model: - requires_backends(cls, ["timm"]) - imagenet_subset = infer_imagenet_subset(config_dict) - if imagenet_subset: - dataset_info = ImageNetInfo(imagenet_subset) - synsets = dataset_info.label_names() - label_descriptions = dataset_info.label_descriptions(as_dict=True) - label_names = [label_descriptions[synset] for synset in synsets] - - if label_names is not None and not is_custom_model: - kwargs["id2label"] = dict(enumerate(label_names)) - - # if all label names are unique, create label2id mapping as well - if len(set(label_names)) == len(label_names): - kwargs["label2id"] = {name: i for i, name in enumerate(label_names)} - else: - kwargs["label2id"] = None - - # timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict. - # We are removing these attributes in order to have the native `transformers` num_labels attribute in config - # and to avoid duplicate attributes - num_labels_in_kwargs = kwargs.pop("num_labels", None) - num_labels_in_dict = config_dict.pop("num_classes", None) - - # passed num_labels has priority over num_classes in config_dict - kwargs["num_labels"] = num_labels_in_kwargs or num_labels_in_dict - - # pop num_classes from "pretrained_cfg", - # it is not necessary to have it, only root one is used in timm - if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]: - config_dict["pretrained_cfg"].pop("num_classes", None) - - return super().from_dict(config_dict, **kwargs) - - def to_dict(self) -> dict[str, Any]: - output = super().to_dict() - output.setdefault("num_classes", self.num_labels) - output.setdefault("label_names", list(self.id2label.values())) - output.pop("id2label", None) - output.pop("label2id", None) - return output - class Gemma3nConfig(PreTrainedConfig): r""" diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index e22def7b0d87..65ecd63f035b 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -38,13 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - ModelOutput, - TransformersKwargs, - auto_docstring, - can_return_tuple, - torch_compilable_check, -) +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel diff 
--git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index a97cc2823c7b..d563a7e57d4e 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -56,7 +56,6 @@ PaliGemmaModel, PaligemmaModelOutputWithPast, ) -from ..timm_wrapper.configuration_timm_wrapper import TimmWrapperConfig logger = logging.get_logger(__name__) @@ -444,7 +443,7 @@ def __init__( self.sscp_conv_stride_size = sscp_conv_stride_size -class Gemma3nVisionConfig(TimmWrapperConfig): +class Gemma3nVisionConfig(PreTrainedConfig): r""" This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`]. It is used to instantiate an timm model model according to the specified arguments, defining the model architecture. @@ -510,6 +509,10 @@ def __init__( self.vocab_size = vocab_size self.vocab_offset = vocab_offset self.rms_norm_eps = rms_norm_eps + self.architecture = architecture + self.initializer_range = initializer_range + self.do_pooling = do_pooling + self.model_args = model_args # named "model_args" for BC with timm super().__init__(**kwargs) diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index 26952ada4894..2f53024db563 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -23,12 +23,11 @@ from ... import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone from ...file_utils import ModelOutput from ...integrations import use_kernel_forward_from_hub from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging, torch_compilable_check -from ..auto import AutoModel +from ..auto import AutoBackbone, AutoModel from .configuration_grounding_dino import GroundingDinoConfig @@ -368,7 +367,7 @@ def __init__(self, config): super().__init__() self.config = config - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) # replace batch norm by frozen batch norm with torch.no_grad(): diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index a52aaa1cda51..0e63987abbbc 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -128,10 +128,6 @@ class LlavaPreTrainedModel(PreTrainedModel): """ ) class LlavaModel(LlavaPreTrainedModel): - _checkpoint_conversion_mapping = { - r"^language_model.model": "language_model", - } - def __init__(self, config: LlavaConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) @@ -282,12 +278,6 @@ def forward( """ ) class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): - _checkpoint_conversion_mapping = { - r"^language_model.model": "model.language_model", - r"^vision_tower": "model.vision_tower", - r"^multi_modal_projector": "model.multi_modal_projector", - r"^language_model.lm_head": "lm_head", - } _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: LlavaConfig): diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index 3d6a2acd0968..8e4f19d05cc9 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py 
+++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -23,13 +23,13 @@ from ... import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone from ...file_utils import ModelOutput, is_scipy_available, requires_backends from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...pytorch_utils import compile_compatible_method_lru_cache from ...utils import auto_docstring, is_accelerate_available, logging, torch_compilable_check +from ..auto import AutoBackbone from .configuration_mask2former import Mask2FormerConfig @@ -1399,7 +1399,7 @@ def __init__(self, config: Mask2FormerConfig): """ super().__init__() - self.encoder = load_backbone(config) + self.encoder = AutoBackbone.from_config(config=config.backbone_config) self.decoder = Mask2FormerPixelDecoder(config, feature_channels=self.encoder.channels) def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> Mask2FormerPixelLevelModuleOutput: diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index cacb86b788ed..43250ef96983 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -14,6 +14,7 @@ """PyTorch MaskFormer model.""" import math +from collections.abc import Callable from dataclasses import dataclass from numbers import Number @@ -23,11 +24,10 @@ from ... import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithCrossAttentions -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import compile_compatible_method_lru_cache from ...utils import ( @@ -39,6 +39,9 @@ logging, requires_backends, ) +from ...utils.generic import merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from ..auto import AutoBackbone from ..detr import DetrConfig from .configuration_maskformer import MaskFormerConfig from .configuration_maskformer_swin import MaskFormerSwinConfig @@ -387,206 +390,262 @@ def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float = return loss / height_and_width -# TODO: use modular - Copied from transformers.models.detr.modeling_detr.DetrAttention -class DetrAttention(nn.Module): +# Copied from transformers.models.detr.modeling_detr.DetrMLP +class DetrMLP(nn.Module): + def __init__(self, config: DetrConfig, hidden_size: int, intermediate_size: int): + super().__init__() + self.fc1 = nn.Linear(hidden_size, intermediate_size) + self.fc2 = nn.Linear(intermediate_size, hidden_size) + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.dropout = config.dropout + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, 
training=self.training) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float | None = None, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.detr.modeling_detr.DetrSelfAttention +class DetrSelfAttention(nn.Module): """ - Multi-headed attention from 'Attention Is All You Need' paper. + Multi-headed self-attention from 'Attention Is All You Need' paper. - Here, we add position embeddings to the queries and keys (as explained in the DETR paper). + In DETR, position embeddings are added to both queries and keys (but not values) in self-attention. """ def __init__( self, - embed_dim: int, - num_heads: int, + config: DetrConfig, + hidden_size: int, + num_attention_heads: int, dropout: float = 0.0, bias: bool = True, ): super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - if self.head_dim * num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {num_heads})." 
- ) + self.config = config + self.head_dim = hidden_size // num_attention_heads self.scaling = self.head_dim**-0.5 + self.attention_dropout = dropout + self.is_causal = False - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - - def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): - return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def with_pos_embed(self, tensor: torch.Tensor, object_queries: Tensor | None): - return tensor if object_queries is None else tensor + object_queries + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, - object_queries: torch.Tensor | None = None, - key_value_states: torch.Tensor | None = None, - spatial_position_embeddings: torch.Tensor | None = None, - output_attentions: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: - """Input shape: Batch x Time x Channel""" - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - batch_size, target_len, embed_dim = hidden_states.size() - - # add position embeddings to the hidden states before projecting to queries and keys - if object_queries is not None: - hidden_states_original = hidden_states - hidden_states = self.with_pos_embed(hidden_states, object_queries) - - # add key-value position embeddings to the key value states - if spatial_position_embeddings is not None: - key_value_states_original = key_value_states - key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings) - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, batch_size) - value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, batch_size) - value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size) + position_embeddings: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Position embeddings are added to both queries and keys (but not values). 
+ """ + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - proj_shape = (batch_size * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states - source_len = key_states.size(1) + query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) - if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len): - raise ValueError( - f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" - f" {attn_weights.size()}" - ) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) - if attention_mask is not None: - if attention_mask.size() != (batch_size, 1, target_len, source_len): - raise ValueError( - f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is" - f" {attention_mask.size()}" - ) - if attention_mask.dtype == torch.bool: - attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_( - attention_mask, -torch.inf - ) - attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask - attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) - attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) - else: - attn_weights_reshaped = None + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.bmm(attn_probs, value_states) +# Copied from transformers.models.detr.modeling_detr.DetrCrossAttention +class DetrCrossAttention(nn.Module): + """ + Multi-headed cross-attention from 'Attention Is All You Need' paper. - if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + In DETR, queries get their own position embeddings, while keys get encoder position embeddings. + Values don't get any position embeddings. 
+ """ - attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(batch_size, target_len, embed_dim) + def __init__( + self, + config: DetrConfig, + hidden_size: int, + num_attention_heads: int, + dropout: float = 0.0, + bias: bool = True, + ): + super().__init__() + self.config = config + self.head_dim = hidden_size // num_attention_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = dropout + self.is_causal = False - attn_output = self.out_proj(attn_output) + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias) - return attn_output, attn_weights_reshaped + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: torch.Tensor | None = None, + encoder_position_embeddings: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Position embeddings logic: + - Queries get position_embeddings + - Keys get encoder_position_embeddings + - Values don't get any position embeddings + """ + query_input_shape = hidden_states.shape[:-1] + query_hidden_shape = (*query_input_shape, -1, self.head_dim) + + kv_input_shape = key_value_states.shape[:-1] + kv_hidden_shape = (*kv_input_shape, -1, self.head_dim) + + query_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states + key_input = ( + key_value_states + encoder_position_embeddings + if encoder_position_embeddings is not None + else key_value_states + ) + query_states = self.q_proj(query_input).view(query_hidden_shape).transpose(1, 2) + key_states = self.k_proj(key_input).view(kv_hidden_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(kv_hidden_shape).transpose(1, 2) -# TODO: use modular - Copied from transformers.models.detr.modeling_detr.DetrDecoderLayer + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*query_input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +# Copied from transformers.models.detr.modeling_detr.DetrDecoderLayer class DetrDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: DetrConfig): super().__init__() - self.embed_dim = config.d_model + self.hidden_size = config.d_model - self.self_attn = DetrAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, + self.self_attn = DetrSelfAttention( + config=config, + hidden_size=self.hidden_size, + num_attention_heads=config.decoder_attention_heads, dropout=config.attention_dropout, ) self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = DetrAttention( - self.embed_dim, - config.decoder_attention_heads, + self.self_attn_layer_norm = 
nn.LayerNorm(self.hidden_size) + self.encoder_attn = DetrCrossAttention( + config=config, + hidden_size=self.hidden_size, + num_attention_heads=config.decoder_attention_heads, dropout=config.attention_dropout, ) - self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) - self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size) + self.mlp = DetrMLP(config, self.hidden_size, config.decoder_ffn_dim) + self.final_layer_norm = nn.LayerNorm(self.hidden_size) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, - object_queries: torch.Tensor | None = None, - query_position_embeddings: torch.Tensor | None = None, + spatial_position_embeddings: torch.Tensor | None = None, + object_queries_position_embeddings: torch.Tensor | None = None, encoder_hidden_states: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, - output_attentions: bool | None = False, **kwargs: Unpack[TransformersKwargs], - ): + ) -> torch.Tensor: """ Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)` attention_mask (`torch.FloatTensor`): attention mask of size `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative values. - object_queries (`torch.FloatTensor`, *optional*): - object_queries that are added to the hidden states - in the cross-attention layer. - query_position_embeddings (`torch.FloatTensor`, *optional*): - position embeddings that are added to the queries and keys - in the self-attention layer. + spatial_position_embeddings (`torch.FloatTensor`, *optional*): + Spatial position embeddings (2D positional encodings from encoder) that are added to the keys only + in the cross-attention layer (not to values). + object_queries_position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings for the object query slots. In self-attention, these are added to both queries + and keys (not values). In cross-attention, these are added to queries only (not to keys or values). encoder_hidden_states (`torch.FloatTensor`): - cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + cross attention input to the layer of shape `(batch, seq_len, hidden_size)` encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. 
""" residual = hidden_states # Self Attention - hidden_states, self_attn_weights = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, - object_queries=query_position_embeddings, + position_embeddings=object_queries_position_embeddings, attention_mask=attention_mask, - output_attentions=output_attentions, + **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -594,17 +653,16 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states, _ = self.encoder_attn( hidden_states=hidden_states, - object_queries=query_position_embeddings, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - spatial_position_embeddings=object_queries, - output_attentions=output_attentions, + position_embeddings=object_queries_position_embeddings, + encoder_position_embeddings=spatial_position_embeddings, + **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -613,66 +671,52 @@ def forward( # Fully Connected residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - return outputs + return hidden_states +# copied from transformers.models.detr.modeling_detr.DetrDecoder with DetrPreTrainedModel->PreTrainedModel class DetrDecoder(PreTrainedModel): - config: DetrConfig - base_model_prefix = "model" - """ - Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`]. - - The decoder updates the query embeddings through multiple self-attention and cross-attention layers. - - Some small tweaks for DETR: - - - object_queries and query_position_embeddings are added to the forward pass. - - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers. + Transformer decoder that refines a set of object queries. It is composed of a stack of [`DetrDecoderLayer`] modules, + which apply self-attention to the queries and cross-attention to the encoder's outputs. Args: - config: DetrConfig + config (`DetrConfig`): Model configuration object. 
""" + _can_record_outputs = { + "hidden_states": DetrDecoderLayer, + "attentions": DetrSelfAttention, + "cross_attentions": DetrCrossAttention, + } + def __init__(self, config: DetrConfig): super().__init__(config) - self.dropout = config.dropout - self.layerdrop = config.decoder_layerdrop self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)]) # in DETR, the decoder uses layernorm after the last decoder layer output self.layernorm = nn.LayerNorm(config.d_model) - self.gradient_checkpointing = False - + # Initialize weights and apply final processing self.post_init() + @merge_with_config_defaults + @capture_outputs def forward( self, inputs_embeds=None, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, - object_queries=None, - query_position_embeddings=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + spatial_position_embeddings=None, + object_queries_position_embeddings=None, **kwargs: Unpack[TransformersKwargs], - ): + ) -> DetrDecoderOutput: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): @@ -695,99 +739,59 @@ def forward( - 1 for pixels that are real (i.e. **not masked**), - 0 for pixels that are padding (i.e. **masked**). - object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Position embeddings that are added to the queries and keys in each cross-attention layer. - query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): - , *optional*): Position embeddings that are added to the queries and keys in each self-attention layer. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Spatial position embeddings (2D positional encodings from encoder) that are added to the keys in each cross-attention layer. + object_queries_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Position embeddings for the object query slots that are added to the queries and keys in each self-attention layer. 
""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is not None: hidden_states = inputs_embeds - encoder_attention_mask = create_bidirectional_mask( - config=self.config, - inputs_embeds=inputs_embeds, - attention_mask=encoder_attention_mask, - encoder_hidden_states=encoder_hidden_states, - ) + if attention_mask is not None: + attention_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=hidden_states, + attention_mask=attention_mask, + ) + + # expand encoder attention mask (for cross-attention on encoder outputs) + if encoder_hidden_states is not None and encoder_attention_mask is not None: + encoder_attention_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=hidden_states, + attention_mask=encoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + ) # optional intermediate hidden states intermediate = () if self.config.auxiliary_loss else None # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - if self.training: - dropout_probability = torch.rand([]) - if dropout_probability < self.layerdrop: - continue - - layer_outputs = decoder_layer( + hidden_states = decoder_layer( hidden_states, - None, # attention_mask - object_queries, - query_position_embeddings, + attention_mask, + spatial_position_embeddings, + object_queries_position_embeddings, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, **kwargs, ) - hidden_states = layer_outputs[0] - if self.config.auxiliary_loss: hidden_states = self.layernorm(hidden_states) intermediate += (hidden_states,) - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - # finally, apply layernorm hidden_states = self.layernorm(hidden_states) - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - # stack intermediate decoder activations if self.config.auxiliary_loss: intermediate = torch.stack(intermediate) - if not return_dict: - return tuple( - v - for v in [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions, intermediate] - if v is not None - ) - return DetrDecoderOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - intermediate_hidden_states=intermediate, - ) + return DetrDecoderOutput(last_hidden_state=hidden_states, intermediate_hidden_states=intermediate) # refactored from original implementation @@ -1352,7 +1356,7 @@ def __init__(self, config: MaskFormerConfig): backbone_config = MaskFormerSwinConfig.from_dict(backbone_config.to_dict()) backbone_config.out_features = ["stage1", "stage2", "stage3", "stage4"] config.backbone_config = backbone_config - 
self.encoder = load_backbone(config) + self.encoder = AutoBackbone.from_config(config=config.backbone_config) feature_channels = self.encoder.channels self.decoder = MaskFormerPixelDecoder( @@ -1428,8 +1432,8 @@ def forward( attention_mask=None, encoder_hidden_states=image_features, encoder_attention_mask=None, - object_queries=object_queries, - query_position_embeddings=queries_embeddings, + spatial_position_embeddings=object_queries, + object_queries_position_embeddings=queries_embeddings, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, diff --git a/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py b/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py index 6acaccf70023..18e8404f0231 100644 --- a/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +++ b/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py @@ -27,11 +27,11 @@ from ... import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone from ...file_utils import ModelOutput from ...integrations import use_kernel_forward_from_hub from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, torch_compilable_check +from ..auto import AutoBackbone from ..auto.modeling_auto import AutoModel from .configuration_mm_grounding_dino import MMGroundingDinoConfig @@ -646,7 +646,7 @@ def __init__(self, config): super().__init__() self.config = config - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) # replace batch norm by frozen batch norm with torch.no_grad(): diff --git a/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py b/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py index e662f318a69b..a00ca746ce0a 100644 --- a/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py +++ b/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py @@ -187,9 +187,13 @@ def __init__( ): # Init timm backbone with hardcoded values for BC timm_default_kwargs = { + "model_args": { + "img_size": image_size, + "features_only": True, + "pretrained": False, + "always_partition": True, + }, "out_indices": [1, 2, 3], - "img_size": image_size, - "always_partition": True, } backbone_config, kwargs = consolidate_backbone_kwargs_to_config( backbone_config=backbone_config, diff --git a/src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py b/src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py index 79aaddc2a8ba..093337f4ec72 100644 --- a/src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py +++ b/src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py @@ -69,7 +69,7 @@ def create_rename_keys_vision(state_dict, config): for layer_name in state_dict: if layer_name.startswith("backbone") and not layer_name.startswith("backbone.norm"): if config.use_timm_backbone: - layer_name_replace = layer_name.replace("backbone", "vision_backbone.vision_backbone._backbone") + layer_name_replace = layer_name.replace("backbone", "vision_backbone.vision_backbone.timm_model") layer_name_replace = layer_name_replace.replace(".layers.", ".layers_") if "downsample" in layer_name: # get layer number diff --git a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py index a30b52897051..fb172a009642 100644 --- a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py +++ 
b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py @@ -26,7 +26,6 @@ from ... import initialization as init from ...activations import ACT2CLS, ACT2FN -from ...backbone_utils import load_backbone from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer @@ -39,7 +38,7 @@ logging, torch_compilable_check, ) -from ..auto import AutoModel +from ..auto import AutoBackbone, AutoModel from .configuration_omdet_turbo import OmDetTurboConfig @@ -279,7 +278,7 @@ class OmDetTurboVisionBackbone(nn.Module): def __init__(self, config: OmDetTurboConfig): super().__init__() self.apply_layernorm_after_vision_backbone = config.apply_layernorm_after_vision_backbone - self.vision_backbone = load_backbone(config) + self.vision_backbone = AutoBackbone.from_config(config=config.backbone_config) self.layer_norms = nn.ModuleList( [nn.LayerNorm(in_channel_dim, eps=config.layer_norm_eps) for in_channel_dim in config.encoder_in_channels] ) diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index 2b4433e52fa9..feb959e86b14 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -24,7 +24,6 @@ from ... import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel @@ -39,6 +38,7 @@ torch_compilable_check, ) from ...utils.generic import maybe_autocast +from ..auto import AutoBackbone from .configuration_oneformer import OneFormerConfig @@ -1456,7 +1456,7 @@ def __init__(self, config: OneFormerConfig): The configuration used to instantiate this model. 
""" super().__init__() - self.encoder = load_backbone(config) + self.encoder = AutoBackbone.from_config(config=config.backbone_config) self.decoder = OneFormerPixelDecoder(config, feature_channels=self.encoder.channels) def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> OneFormerPixelLevelModuleOutput: diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 15f2071ee2bc..d5465e61763e 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -236,7 +236,6 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): """ ) class PaliGemmaModel(PaliGemmaPreTrainedModel): - _checkpoint_conversion_mapping = {"language_model.model": "language_model"} # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch accepts_loss_kwargs = False @@ -428,12 +427,6 @@ def forward( """ ) class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin): - _checkpoint_conversion_mapping = { - "^language_model.model": "model.language_model", - "^vision_tower": "model.vision_tower", - "^multi_modal_projector": "model.multi_modal_projector", - "^language_model.lm_head": "lm_head", - } _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: PaliGemmaConfig): diff --git a/src/transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py b/src/transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py index d4c275b93eed..e5ee32f818bf 100644 --- a/src/transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +++ b/src/transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py @@ -30,7 +30,6 @@ from ... 
import initialization as init from ...activations import ACT2CLS, ACT2FN -from ...backbone_utils import load_backbone from ...image_transforms import center_to_corners_format, corners_to_center_format from ...integrations import use_kernel_forward_from_hub from ...modeling_outputs import BaseModelOutput @@ -46,6 +45,7 @@ ) from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ..auto import AutoBackbone from .configuration_pp_doclayout_v3 import PPDocLayoutV3Config @@ -1322,7 +1322,7 @@ class PPDocLayoutV3ConvEncoder(nn.Module): def __init__(self, config): super().__init__() - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) if config.freeze_backbone_batch_norms: # replace batch norm by frozen batch norm diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index 5461d87c609a..bca80f173dd0 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -20,11 +20,11 @@ import torch import torch.nn as nn -from ...backbone_utils import load_backbone from ...modeling_outputs import DepthEstimatorOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring from ...utils.generic import torch_int +from ..auto import AutoBackbone from .configuration_prompt_depth_anything import PromptDepthAnythingConfig @@ -375,7 +375,7 @@ class PromptDepthAnythingForDepthEstimation(PromptDepthAnythingPreTrainedModel): def __init__(self, config): super().__init__(config) - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config=config.backbone_config) self.neck = PromptDepthAnythingNeck(config) self.head = PromptDepthAnythingDepthEstimationHead(config) diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 3a403942fc18..d5cd0cb13c5e 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -972,7 +972,6 @@ def forward( @auto_docstring class Qwen2VLModel(Qwen2VLPreTrainedModel): base_model_prefix = "model" - _checkpoint_conversion_mapping = {"^model": "language_model"} # Reference: fix gemma3 grad acc #37208 accepts_loss_kwargs = False @@ -1370,10 +1369,6 @@ def forward( class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): - _checkpoint_conversion_mapping = { - "^visual": "model.visual", - r"^model(?!\.(language_model|visual))": "model.language_model", - } _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config): diff --git a/src/transformers/models/rt_detr/modeling_rt_detr.py b/src/transformers/models/rt_detr/modeling_rt_detr.py index 182d4b2c054a..ec9aa7fbe7f3 100644 --- a/src/transformers/models/rt_detr/modeling_rt_detr.py +++ b/src/transformers/models/rt_detr/modeling_rt_detr.py @@ -28,7 +28,6 @@ from ... 
import initialization as init from ...activations import ACT2CLS, ACT2FN -from ...backbone_utils import load_backbone from ...image_transforms import center_to_corners_format, corners_to_center_format from ...integrations import use_kernel_forward_from_hub from ...modeling_outputs import BaseModelOutput @@ -38,6 +37,7 @@ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check, torch_int from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ..auto import AutoBackbone from .configuration_rt_detr import RTDetrConfig @@ -400,7 +400,7 @@ class RTDetrConvEncoder(nn.Module): def __init__(self, config): super().__init__() - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) if config.freeze_backbone_batch_norms: # replace batch norm by frozen batch norm diff --git a/src/transformers/models/rt_detr/modular_rt_detr.py b/src/transformers/models/rt_detr/modular_rt_detr.py index f9289f9e6619..28f33ddf55a1 100644 --- a/src/transformers/models/rt_detr/modular_rt_detr.py +++ b/src/transformers/models/rt_detr/modular_rt_detr.py @@ -23,7 +23,6 @@ from ... import initialization as init from ...activations import ACT2CLS, ACT2FN -from ...backbone_utils import load_backbone from ...image_processing_utils import BatchFeature from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict, get_max_height_width from ...image_transforms import center_to_corners_format, corners_to_center_format @@ -52,6 +51,7 @@ ) from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ..auto import AutoBackbone from ..conditional_detr.modeling_conditional_detr import inverse_sigmoid from ..deformable_detr.modeling_deformable_detr import DeformableDetrMultiscaleDeformableAttention from ..detr.image_processing_detr_fast import DetrImageProcessorFast @@ -689,7 +689,7 @@ class RTDetrConvEncoder(nn.Module): def __init__(self, config): super().__init__() - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) if config.freeze_backbone_batch_norms: # replace batch norm by frozen batch norm diff --git a/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py b/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py index b5244ffda7f8..3a31e82040a7 100644 --- a/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +++ b/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py @@ -29,7 +29,6 @@ from ... 
import initialization as init from ...activations import ACT2CLS, ACT2FN -from ...backbone_utils import load_backbone from ...image_transforms import center_to_corners_format, corners_to_center_format from ...modeling_outputs import BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -38,6 +37,7 @@ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check, torch_int from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ..auto import AutoBackbone from .configuration_rt_detr_v2 import RTDetrV2Config @@ -796,7 +796,7 @@ class RTDetrV2ConvEncoder(nn.Module): def __init__(self, config): super().__init__() - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) if config.freeze_backbone_batch_norms: # replace batch norm by frozen batch norm diff --git a/src/transformers/models/table_transformer/configuration_table_transformer.py b/src/transformers/models/table_transformer/configuration_table_transformer.py index f8c3b2e79320..2d293f22b5f1 100644 --- a/src/transformers/models/table_transformer/configuration_table_transformer.py +++ b/src/transformers/models/table_transformer/configuration_table_transformer.py @@ -158,13 +158,15 @@ def __init__( ): backbone_kwargs = kwargs.get("backbone_kwargs", {}) timm_default_kwargs = { - "num_channels": backbone_kwargs.get("num_channels", num_channels), - "features_only": True, - "use_pretrained_backbone": False, + "model_args": { + "in_chans": backbone_kwargs.get("num_channels", num_channels), + "features_only": True, + "pretrained": False, + }, "out_indices": backbone_kwargs.get("out_indices", [1, 2, 3, 4]), } if dilation: - timm_default_kwargs["output_stride"] = backbone_kwargs.get("output_stride", 16) + timm_default_kwargs["model_args"]["output_stride"] = backbone_kwargs.get("output_stride", 16) backbone_config, kwargs = consolidate_backbone_kwargs_to_config( backbone_config=backbone_config, diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index 68c66ce8248b..ccc203559807 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -21,7 +21,6 @@ from ... 
import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput @@ -31,6 +30,7 @@ auto_docstring, logging, ) +from ..auto import AutoBackbone from .configuration_table_transformer import TableTransformerConfig @@ -204,7 +204,7 @@ def __init__(self, config): self.config = config - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config=config.backbone_config) self.intermediate_channel_sizes = backbone.channels # replace batch norm by frozen batch norm @@ -212,10 +212,10 @@ def __init__(self, config): replace_batch_norm(backbone) # We used to load with timm library directly instead of the AutoBackbone API - # so we need to unwrap the `backbone._backbone` module to load weights without mismatch + # so we need to unwrap the `backbone.timm_model` module to load weights without mismatch is_timm_model = False - if hasattr(backbone, "_backbone"): - backbone = backbone._backbone + if hasattr(backbone, "timm_model"): + backbone = backbone.timm_model is_timm_model = True self.model = backbone diff --git a/src/transformers/models/timm_backbone/configuration_timm_backbone.py b/src/transformers/models/timm_backbone/configuration_timm_backbone.py index ba5ed4eb14bf..6f63b4bfcc4c 100644 --- a/src/transformers/models/timm_backbone/configuration_timm_backbone.py +++ b/src/transformers/models/timm_backbone/configuration_timm_backbone.py @@ -14,15 +14,14 @@ """Configuration for Backbone models""" -from ...backbone_utils import BackboneConfigMixin -from ...configuration_utils import PreTrainedConfig from ...utils import logging +from ..timm_wrapper.configuration_timm_wrapper import TimmWrapperConfig logger = logging.get_logger(__name__) -class TimmBackboneConfig(BackboneConfigMixin, PreTrainedConfig): +class TimmBackboneConfig(TimmWrapperConfig): r""" This is the configuration class to store the configuration for a timm backbone [`TimmBackbone`]. @@ -31,6 +30,8 @@ class TimmBackboneConfig(BackboneConfigMixin, PreTrainedConfig): Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PreTrainedConfig`] for more information. + Note that the config class is deprecated, use `TimmWrapperConfig` instead! + Args: backbone (`str`, *optional*): The timm checkpoint to load. @@ -60,6 +61,12 @@ class TimmBackboneConfig(BackboneConfigMixin, PreTrainedConfig): """ model_type = "timm_backbone" + special_attribute_map = { + "backbone": "architecture", + "num_channels": ("model_args", "in_chans"), + "output_stride": ("model_args", "output_stride"), + "features_only": ("model_args", "features_only"), + } def __init__( self, @@ -71,40 +78,38 @@ def __init__( output_stride=None, **kwargs, ): - self.backbone = backbone - self.num_channels = num_channels - self.features_only = features_only self.out_indices = out_indices if out_indices is not None else [-1] - self.output_stride = output_stride - self.freeze_batch_norm_2d = freeze_batch_norm_2d - - super().__init__(**kwargs) - - @property - def out_indices(self): - return self._out_indices - - @out_indices.setter - def out_indices(self, out_indices: tuple[int, ...] | list[int]): - """ - Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices. 
- """ - self._out_indices = list(out_indices) if out_indices is not None else out_indices - if getattr(self, "stage_names", None) is not None: - self.set_output_features_output_indices(out_features=None, out_indices=out_indices) - - @property - def out_features(self): - return self._out_features - - @out_features.setter - def out_features(self, out_features: list[str]): - """ - Set the out_features attribute. This will also update the out_indices attribute to match the new out_features. - """ - self._out_features = out_features - if getattr(self, "stage_names", None) is not None: - self.set_output_features_output_indices(out_features=out_features, out_indices=None) + kwargs["architecture"] = backbone if backbone is not None else kwargs.get("architecture") + kwargs["do_pooling"] = False # hardcode for backbone model + if kwargs.get("model_args") is None: + kwargs["model_args"] = { + "features_only": features_only, + "in_chans": num_channels, + "output_stride": output_stride, + } + logger.warning( + "TimmBackboneConfig is deprecate and will be removed in future versions. " + "Use a TimmWrapperConfig instead with TimmWrapperBackboneModel to extract features." + ) + super().__init__(freeze_batch_norm_2d=freeze_batch_norm_2d, **kwargs) + + def __setattr__(self, key, value): + if (mapped_key := super().__getattribute__("special_attribute_map").get(key)) is not None: + if isinstance(mapped_key, (tuple, list)): + model_args = super().__getattribute__("__dict__").get(mapped_key[0]) + model_args[mapped_key[1]] = value + else: + setattr(self, mapped_key[1], value) + else: + super().__setattr__(key, value) + + def __getattribute__(self, key): + if (mapped_key := super().__getattribute__("special_attribute_map").get(key)) is not None: + if isinstance(mapped_key, (tuple, list)): + model_args = super().__getattribute__(mapped_key[0]) + return model_args[mapped_key[1]] + return getattr(self, mapped_key[1]) + return super().__getattribute__(key) __all__ = ["TimmBackboneConfig"] diff --git a/src/transformers/models/timm_backbone/modeling_timm_backbone.py b/src/transformers/models/timm_backbone/modeling_timm_backbone.py index d46796d6f4cd..973d8e7b7b41 100644 --- a/src/transformers/models/timm_backbone/modeling_timm_backbone.py +++ b/src/transformers/models/timm_backbone/modeling_timm_backbone.py @@ -13,146 +13,26 @@ # limitations under the License. -import torch -from torch import Tensor, nn +from ...utils import logging +from ..timm_wrapper.modeling_timm_wrapper import TimmWrapperBackboneModel -from ... import initialization as init -from ...backbone_utils import BackboneMixin -from ...modeling_outputs import BackboneOutput -from ...modeling_utils import PreTrainedModel -from ...utils import is_timm_available, requires_backends -from .configuration_timm_backbone import TimmBackboneConfig +logger = logging.get_logger(__name__) -if is_timm_available(): - import timm - -class TimmBackbone(BackboneMixin, PreTrainedModel): +class TimmBackbone(TimmWrapperBackboneModel): """ Wrapper class for timm models to be used as backbones. This enables using the timm models interchangeably with the other models in the library keeping the same API. + Note that the model is deprecated, use `TimmWrapperBackboneModel` instead! """ - main_input_name = "pixel_values" - input_modalities = ("image",) - supports_gradient_checkpointing = False - config: TimmBackboneConfig - - def __init__(self, config, **kwargs): - requires_backends(self, "timm") - - if config.backbone is None: - raise ValueError("backbone is not set in the config. 
Please set it to a timm model name.") - - # We just take the final layer by default. This matches the default for the transformers models. - out_indices = config.out_indices if getattr(config, "out_indices", None) is not None else (-1,) - pretrained = kwargs.pop("pretrained", False) - in_chans = kwargs.pop("in_chans", config.num_channels) - - backbone = timm.create_model( - config.backbone, - pretrained=pretrained, - # This is currently not possible for transformer architectures. - features_only=config.features_only, - in_chans=in_chans, - out_indices=out_indices, - output_stride=config.output_stride, - **kwargs, - ) - - # Needs to be called after creating timm model, because `super()` will try to infer - # `stage_names` from model architecture - super().__init__(config, timm_backbone=backbone) - self._backbone = backbone - - # Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct2d` layers of - # provided module into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively - if getattr(config, "freeze_batch_norm_2d", False): - self.freeze_batch_norm_2d() - - # These are used to control the output of the model when called. If output_hidden_states is True, then - # return_layers is modified to include all layers. - self._return_layers = { - layer["module"]: str(layer["index"]) for layer in self._backbone.feature_info.get_dicts() - } - self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)} - - self.post_init() - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - requires_backends(cls, ["vision", "timm"]) - - config = kwargs.pop("config", TimmBackboneConfig()) - num_channels = kwargs.pop("num_channels", config.num_channels) - features_only = kwargs.pop("features_only", config.features_only) - out_indices = kwargs.pop("out_indices", config.out_indices) - config = TimmBackboneConfig( - backbone=pretrained_model_name_or_path, - num_channels=num_channels, - features_only=features_only, - out_indices=out_indices, - ) - return super()._from_config(config, pretrained=True, **kwargs) - - def freeze_batch_norm_2d(self): - timm.utils.model.freeze_batch_norm_2d(self._backbone) - - def unfreeze_batch_norm_2d(self): - timm.utils.model.unfreeze_batch_norm_2d(self._backbone) - - @torch.no_grad() - def _init_weights(self, module): - """We need to at least re-init the non-persistent buffers if the model was initialized on meta device (we - assume weights and persistent buffers will be part of checkpoint as we have no way to control timm inits)""" - if hasattr(module, "init_non_persistent_buffers"): - module.init_non_persistent_buffers() - elif isinstance(module, nn.BatchNorm2d): - # For non-pretrained models, always initialize buffers (handles both meta device and to_empty() cases) - running_mean = getattr(module, "running_mean", None) - if running_mean is not None: - init.zeros_(module.running_mean) - init.ones_(module.running_var) - init.zeros_(module.num_batches_tracked) - - def forward( - self, - pixel_values: torch.FloatTensor, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - **kwargs, - ) -> BackboneOutput | tuple[Tensor, ...]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + def __init__(self, *args, **kwargs): + logger.warning( + 
"`TimmBackbone` is deprecate and will be removed in future versions. Use a " + "`TimmWrapperBackboneModel` init from `TimmWrapperConfig` instead to extract features." ) - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - - if output_attentions: - raise ValueError("Cannot output attentions for timm backbones at the moment") - - if output_hidden_states: - # We modify the return layers to include all the stages of the backbone - self._backbone.return_layers = self._all_layers - hidden_states = self._backbone(pixel_values, **kwargs) - self._backbone.return_layers = self._return_layers - feature_maps = tuple(hidden_states[i] for i in self.out_indices) - else: - feature_maps = self._backbone(pixel_values, **kwargs) - hidden_states = None - - feature_maps = tuple(feature_maps) - hidden_states = tuple(hidden_states) if hidden_states is not None else None - - if not return_dict: - output = (feature_maps,) - if output_hidden_states: - output = output + (hidden_states,) - return output - - return BackboneOutput(feature_maps=feature_maps, hidden_states=hidden_states, attentions=None) + super().__init__(*args, **kwargs) __all__ = ["TimmBackbone"] diff --git a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py index 5afd27331d6b..82b368101a85 100644 --- a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py @@ -16,18 +16,15 @@ from typing import Any +from ...backbone_utils import BackboneConfigMixin from ...configuration_utils import PreTrainedConfig -from ...utils import is_timm_available, logging, requires_backends - - -if is_timm_available(): - from timm.data import ImageNetInfo, infer_imagenet_subset +from ...utils import logging logger = logging.get_logger(__name__) -class TimmWrapperConfig(PreTrainedConfig): +class TimmWrapperConfig(PreTrainedConfig, BackboneConfigMixin): r""" This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`]. @@ -46,6 +43,8 @@ class TimmWrapperConfig(PreTrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. do_pooling (`bool`, *optional*, defaults to `True`): Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not. + freeze_batch_norm_2d (`bool`, *optional*, defaults to `False`): + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. model_args (`dict[str, Any]`, *optional*): Additional keyword arguments to pass to the `timm.create_model` function. e.g. `model_args={"depth": 3}` for `timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k` to create a model with 3 blocks. Defaults to `None`. 
@@ -69,12 +68,14 @@ def __init__(
         self,
         architecture: str = "resnet50",
         initializer_range: float = 0.02,
         do_pooling: bool = True,
+        freeze_batch_norm_2d=False,
         model_args: dict[str, Any] | None = None,
         **kwargs,
     ):
         self.architecture = architecture
         self.initializer_range = initializer_range
         self.do_pooling = do_pooling
+        self.freeze_batch_norm_2d = freeze_batch_norm_2d
         self.model_args = model_args  # named "model_args" for BC with timm
         super().__init__(**kwargs)
@@ -85,17 +86,6 @@ def from_dict(cls, config_dict: dict[str, Any], **kwargs):
 
         label_names = config_dict.get("label_names")
         is_custom_model = "num_labels" in kwargs or "id2label" in kwargs
-
-        # if no labels added to config, use imagenet labeller in timm
-        if label_names is None and not is_custom_model:
-            requires_backends(cls, ["timm"])
-            imagenet_subset = infer_imagenet_subset(config_dict)
-            if imagenet_subset:
-                dataset_info = ImageNetInfo(imagenet_subset)
-                synsets = dataset_info.label_names()
-                label_descriptions = dataset_info.label_descriptions(as_dict=True)
-                label_names = [label_descriptions[synset] for synset in synsets]
-
         if label_names is not None and not is_custom_model:
             kwargs["id2label"] = dict(enumerate(label_names))
 
@@ -107,17 +97,13 @@ def from_dict(cls, config_dict: dict[str, Any], **kwargs):
 
         # timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict.
         # We are removing these attributes in order to have the native `transformers` num_labels attribute in config
-        # and to avoid duplicate attributes
-        num_labels_in_kwargs = kwargs.pop("num_labels", None)
-        num_labels_in_dict = config_dict.pop("num_classes", None)
-
-        # passed num_labels has priority over num_classes in config_dict
-        kwargs["num_labels"] = num_labels_in_kwargs or num_labels_in_dict
+        # and to avoid duplicate attributes. Note that `num_labels` has priority over `num_classes` in config_dict
+        if config_dict.get("num_classes") is not None and kwargs.get("num_labels") is None:
+            kwargs["num_labels"] = config_dict.pop("num_classes", None)
 
-        # pop num_classes from "pretrained_cfg",
-        # it is not necessary to have it, only root one is used in timm
-        if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]:
-            config_dict["pretrained_cfg"].pop("num_classes", None)
+        # Pop it from the nested `pretrained_cfg` as well
+        if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]:
+            config_dict["pretrained_cfg"].pop("num_classes", None)
 
         return super().from_dict(config_dict, **kwargs)
 
@@ -129,5 +115,31 @@ def to_dict(self) -> dict[str, Any]:
         output.pop("label2id", None)
         return output
 
+    @property
+    def out_indices(self):
+        return self._out_indices
+
+    @out_indices.setter
+    def out_indices(self, out_indices: tuple[int, ...] | list[int]):
+        """
+        Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices.
+        """
+        self._out_indices = list(out_indices) if out_indices is not None else out_indices
+        if getattr(self, "stage_names", None) is not None:
+            self.set_output_features_output_indices(out_features=None, out_indices=out_indices)
+
+    @property
+    def out_features(self):
+        return self._out_features
+
+    @out_features.setter
+    def out_features(self, out_features: list[str]):
+        """
+        Set the out_features attribute. This will also update the out_indices attribute to match the new out_features. 
+ """ + self._out_features = out_features + if getattr(self, "stage_names", None) is not None: + self.set_output_features_output_indices(out_features=out_features, out_indices=None) + __all__ = ["TimmWrapperConfig"] diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py index 6fcfec1389d5..149debf17e0a 100644 --- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -18,15 +18,18 @@ from torch import Tensor, nn from ... import initialization as init -from ...modeling_outputs import ImageClassifierOutput, ModelOutput +from ...backbone_utils import BackboneMixin +from ...modeling_outputs import BackboneOutput, ImageClassifierOutput, ModelOutput from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, is_timm_available, requires_backends +from ...utils import auto_docstring, can_return_tuple, is_timm_available, logging, requires_backends from .configuration_timm_wrapper import TimmWrapperConfig if is_timm_available(): import timm +logger = logging.get_logger(__name__) + @dataclass @auto_docstring( @@ -91,6 +94,12 @@ class TimmWrapperPreTrainedModel(PreTrainedModel): def post_init(self): self.supports_gradient_checkpointing = self._timm_model_supports_gradient_checkpointing() + + # Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct2d` layers of + # provided module into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively + if getattr(self.config, "freeze_batch_norm_2d", False): + self.freeze_batch_norm_2d() + super().post_init() def load_state_dict(self, state_dict, *args, **kwargs): @@ -142,6 +151,12 @@ def _timm_model_supports_gradient_checkpointing(self): def _set_gradient_checkpointing(self, enable: bool = True, *args, **kwargs): self.timm_model.set_grad_checkpointing(enable) + def freeze_batch_norm_2d(self): + timm.utils.model.freeze_batch_norm_2d(self.timm_model) + + def unfreeze_batch_norm_2d(self): + timm.utils.model.unfreeze_batch_norm_2d(self.timm_model) + def get_input_embeddings(self): # TIMM backbones operate directly on images and do not expose token embeddings. return None @@ -150,6 +165,76 @@ def set_input_embeddings(self, value): raise NotImplementedError("TimmWrapper models do not own token embeddings and cannot set them.") +class TimmWrapperBackboneModel(BackboneMixin, TimmWrapperPreTrainedModel): + """ + Wrapper class for timm models to be used as backbones. This enables using the timm models interchangeably with the + other models in the library keeping the same API. + """ + + def __init__(self, config, **kwargs): + requires_backends(self, ["vision", "timm"]) + + extra_init_kwargs = config.model_args or {} + self.features_only = extra_init_kwargs.get("features_only", True) + + # We just take the final layer by default. This matches the default for the transformers models. + out_indices = config.out_indices if getattr(config, "out_indices", None) is not None else (-1,) + timm_backbone = _create_timm_model_with_error_handling(config, out_indices=out_indices, **extra_init_kwargs) + + # Needs to be called after creating timm model, because `super()` will try to infer + # `stage_names` from model architecture + super().__init__(config, timm_backbone=timm_backbone) + self.timm_model = timm_backbone + + # These are used to control the output of the model when called. 
If output_hidden_states is True, then
+        # return_layers is modified to include all layers.
+        self._return_layers = {
+            layer["module"]: str(layer["index"]) for layer in self.timm_model.feature_info.get_dicts()
+        }
+        self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self.timm_model.feature_info.info)}
+
+        self.post_init()
+
+    @property
+    def _backbone(self):
+        logger.warning(
+            f"The `self._backbone` attribute is deprecated for {self.__class__.__name__}. Please use `self.timm_model` instead."
+        )
+        return self.timm_model
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        **kwargs,
+    ) -> BackboneOutput | tuple[Tensor, ...]:
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+        if output_attentions:
+            raise ValueError("Cannot output attentions for timm backbones at the moment")
+
+        if output_hidden_states:
+            # We modify the return layers to include all the stages of the backbone
+            self.timm_model.return_layers = self._all_layers
+            hidden_states = self.timm_model(pixel_values)
+            self.timm_model.return_layers = self._return_layers
+            feature_maps = tuple(hidden_states[i] for i in self.out_indices)
+        else:
+            feature_maps = self.timm_model(pixel_values)
+            hidden_states = None
+
+        feature_maps = tuple(feature_maps)
+        hidden_states = tuple(hidden_states) if hidden_states is not None else None
+
+        return BackboneOutput(feature_maps=feature_maps, hidden_states=hidden_states, attentions=None)
+
+
 class TimmWrapperModel(TimmWrapperPreTrainedModel):
     """
     Wrapper class for timm models to be used in transformers.
@@ -164,22 +249,18 @@ def __init__(self, config: TimmWrapperConfig):
         self.timm_model = _create_timm_model_with_error_handling(config, num_classes=0, **extra_init_kwargs)
         self.post_init()
 
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
         pixel_values: torch.FloatTensor,
         output_attentions: bool | None = None,
         output_hidden_states: bool | list[int] | None = None,
-        return_dict: bool | None = None,
         do_pooling: bool | None = None,
         use_cache: bool | None = None,
         **kwargs,
     ) -> TimmWrapperModelOutput | tuple[Tensor, ...]:
         r"""
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. Not compatible with timm wrapped models.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. Not compatible with timm wrapped models.
         do_pooling (`bool`, *optional*):
             Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not. If `None` is passed, the
             `do_pooling` value from the config is used.
@@ -215,7 +296,6 @@ def forward(
         >>> last_hidden_state = outputs.last_hidden_state
         ```
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
@@ -236,7 +316,11 @@ def forward(
             pixel_values = pixel_values.to(self.device)
 
         if self.features_only:
-            last_hidden_state = self.timm_model.forward(pixel_values, **kwargs)
+            logger.warning(
+                f"Using a `features_only` mode in {self.__class__.__name__} is deprecated and will be removed in v5.20.0. "
+                "Instead, please use `TimmWrapperBackboneModel` to obtain feature maps. 
+ ) + last_hidden_state = self.timm_model.forward(pixel_values) hidden_states = last_hidden_state if output_hidden_states else None pooler_output = None else: @@ -255,11 +339,6 @@ def forward( else: pooler_output = None - if not return_dict: - outputs = (last_hidden_state, pooler_output, hidden_states) - outputs = tuple(output for output in outputs if output is not None) - return outputs - return TimmWrapperModelOutput( last_hidden_state=last_hidden_state, pooler_output=pooler_output, @@ -290,6 +369,7 @@ def __init__(self, config: TimmWrapperConfig): self.num_labels = config.num_labels self.post_init() + @can_return_tuple @auto_docstring def forward( self, @@ -297,7 +377,6 @@ def forward( labels: torch.LongTensor | None = None, output_attentions: bool | None = None, output_hidden_states: bool | list[int] | None = None, - return_dict: bool | None = None, **kwargs, ) -> ImageClassifierOutput | tuple[Tensor, ...]: r""" @@ -305,14 +384,6 @@ def forward( Labels for computing the image classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. Not compatible with timm wrapped models. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. Not compatible with timm wrapped models. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - **kwargs: - Additional keyword arguments passed along to the `timm` model forward. Examples: ```python @@ -342,7 +413,6 @@ def forward( >>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5) ``` """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) @@ -375,11 +445,6 @@ def forward( if labels is not None: loss = self.loss_function(labels, logits, self.config) - if not return_dict: - outputs = (loss, logits, hidden_states) - outputs = tuple(output for output in outputs if output is not None) - return outputs - return ImageClassifierOutput( loss=loss, logits=logits, @@ -387,4 +452,9 @@ def forward( ) -__all__ = ["TimmWrapperPreTrainedModel", "TimmWrapperModel", "TimmWrapperForImageClassification"] +__all__ = [ + "TimmWrapperBackboneModel", + "TimmWrapperPreTrainedModel", + "TimmWrapperModel", + "TimmWrapperForImageClassification", +] diff --git a/src/transformers/models/tvp/modeling_tvp.py b/src/transformers/models/tvp/modeling_tvp.py index 917556c31fff..801da7b399d9 100644 --- a/src/transformers/models/tvp/modeling_tvp.py +++ b/src/transformers/models/tvp/modeling_tvp.py @@ -21,11 +21,11 @@ from ... 
import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging +from ..auto import AutoBackbone from .configuration_tvp import TvpConfig @@ -135,7 +135,7 @@ def forward(self, logits, labels): class TvpVisionModel(nn.Module): def __init__(self, config): super().__init__() - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config=config.backbone_config) if config.backbone_config is not None: in_channels = config.backbone_config.hidden_sizes[-1] diff --git a/src/transformers/models/upernet/modeling_upernet.py b/src/transformers/models/upernet/modeling_upernet.py index bf497134646f..61db23ce7c63 100644 --- a/src/transformers/models/upernet/modeling_upernet.py +++ b/src/transformers/models/upernet/modeling_upernet.py @@ -17,10 +17,10 @@ from torch import nn from torch.nn import CrossEntropyLoss -from ...backbone_utils import load_backbone from ...modeling_outputs import SemanticSegmenterOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring +from ..auto import AutoBackbone from .configuration_upernet import UperNetConfig @@ -279,7 +279,7 @@ class UperNetForSemanticSegmentation(UperNetPreTrainedModel): def __init__(self, config): super().__init__(config) - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config=config.backbone_config) # Semantic segmentation head(s) self.decode_head = UperNetHead(config, in_channels=self.backbone.channels) diff --git a/src/transformers/models/vitmatte/modeling_vitmatte.py b/src/transformers/models/vitmatte/modeling_vitmatte.py index 658d90e8aa85..7d05f1e5d3a6 100644 --- a/src/transformers/models/vitmatte/modeling_vitmatte.py +++ b/src/transformers/models/vitmatte/modeling_vitmatte.py @@ -19,9 +19,9 @@ from torch import nn from ... import initialization as init -from ...backbone_utils import load_backbone from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring +from ..auto import AutoBackbone from .configuration_vitmatte import VitMatteConfig @@ -107,7 +107,11 @@ def __init__(self, config): # to enable loading HF backbone models. in_channels = 4 if config.backbone_config is not None: - in_channels = config.backbone_config.num_channels + in_channels = ( + config.backbone_config.num_channels + if hasattr(config.backbone_config, "num_channels") + else config.backbone_config.model_args["in_chans"] + ) out_channels = config.convstream_hidden_sizes @@ -222,7 +226,7 @@ def __init__(self, config): super().__init__(config) self.config = config - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config=config.backbone_config) self.decoder = VitMatteDetailCaptureModule(config) # Initialize weights and apply final processing diff --git a/src/transformers/models/vitpose/modeling_vitpose.py b/src/transformers/models/vitpose/modeling_vitpose.py index 3e47f66bd03e..f077a2e346c7 100644 --- a/src/transformers/models/vitpose/modeling_vitpose.py +++ b/src/transformers/models/vitpose/modeling_vitpose.py @@ -19,12 +19,12 @@ from torch import nn from ... 
import initialization as init -from ...backbone_utils import load_backbone from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging from ...utils.generic import can_return_tuple +from ..auto import AutoBackbone from .configuration_vitpose import VitPoseConfig @@ -191,7 +191,7 @@ class VitPoseForPoseEstimation(VitPosePreTrainedModel): def __init__(self, config: VitPoseConfig): super().__init__(config) - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config=config.backbone_config) # add backbone attributes if not hasattr(self.backbone.config, "hidden_size"): diff --git a/src/transformers/models/zoedepth/modeling_zoedepth.py b/src/transformers/models/zoedepth/modeling_zoedepth.py index d385ca4080c2..399e677fa696 100644 --- a/src/transformers/models/zoedepth/modeling_zoedepth.py +++ b/src/transformers/models/zoedepth/modeling_zoedepth.py @@ -21,10 +21,10 @@ from ... import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone from ...modeling_outputs import DepthEstimatorOutput from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging +from ..auto import AutoBackbone from .configuration_zoedepth import ZoeDepthConfig @@ -1226,7 +1226,7 @@ class ZoeDepthForDepthEstimation(ZoeDepthPreTrainedModel): def __init__(self, config): super().__init__(config) - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config=config.backbone_config) if hasattr(self.backbone.config, "hidden_size") and hasattr(self.backbone.config, "patch_size"): config.backbone_hidden_size = self.backbone.config.hidden_size diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py index dbc476596a36..9ede499661e1 100644 --- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py +++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py @@ -349,6 +349,12 @@ def attention_mask_padding_matches_padding_free_with_position_ids( tol = torch.finfo(torch.bfloat16).eps torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol) + @unittest.skip( + reason="Conversion happens only with pre-saved ckpt on the hub. 
diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
index dbc476596a36..9ede499661e1 100644
--- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
+++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
@@ -349,6 +349,12 @@ def attention_mask_padding_matches_padding_free_with_position_ids(
         tol = torch.finfo(torch.bfloat16).eps
         torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
 
+    @unittest.skip(
+        reason="Conversion happens only with a pre-saved checkpoint on the Hub. Model init from config doesn't rename any keys"
+    )
+    def test_reverse_loading_mapping(self):
+        pass
+
     @unittest.skip(reason="Feedforward chunking is not yet supported")
     def test_feed_forward_chunking(self):
         pass
diff --git a/tests/models/timm_backbone/test_modeling_timm_backbone.py b/tests/models/timm_backbone/test_modeling_timm_backbone.py
index fa8b24a40aa7..5410b0e7298c 100644
--- a/tests/models/timm_backbone/test_modeling_timm_backbone.py
+++ b/tests/models/timm_backbone/test_modeling_timm_backbone.py
@@ -42,7 +42,7 @@ def __init__(
         batch_size=3,
         image_size=32,
         num_channels=3,
-        is_training=True,
+        is_training=False,
     ):
         self.parent = parent
         self.out_indices = out_indices if out_indices is not None else [4]
@@ -91,7 +91,7 @@ def setUp(self):
         self.config_class = TimmBackboneConfig
         self.model_tester = TimmBackboneModelTester(self)
         self.config_tester = ConfigTester(
-            self, config_class=self.config_class, has_text_modality=False, common_properties=["num_channels"]
+            self, config_class=self.config_class, has_text_modality=False, common_properties=["architecture"]
         )
 
     def test_config(self):
diff --git a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py
index 46ec5c01fe0f..820e614da045 100644
--- a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py
+++ b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py
@@ -28,6 +28,7 @@
 )
 from transformers.utils.import_utils import is_timm_available, is_torch_available, is_vision_available
 
+from ...test_backbone_common import BackboneTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor
 from ...test_pipeline_mixin import PipelineTesterMixin
@@ -36,7 +37,12 @@
 if is_torch_available():
     import torch
 
-    from transformers import TimmWrapperConfig, TimmWrapperForImageClassification, TimmWrapperModel
+    from transformers import (
+        TimmWrapperBackboneModel,
+        TimmWrapperConfig,
+        TimmWrapperForImageClassification,
+        TimmWrapperModel,
+    )
 
 
 if is_timm_available():
@@ -187,13 +193,14 @@ def test_do_pooling_option(self):
         self.assertIsNotNone(output.pooler_output)
 
     def test_timm_config_labels(self):
-        # test timm config with no labels
+        # test timm config with no labels, default labels are created for the 1000 classes
         checkpoint = "timm/resnet18.a1_in1k"
         config = TimmWrapperConfig.from_pretrained(checkpoint)
-        self.assertIsNone(config.label2id)
+        self.assertIsInstance(config.label2id, dict)
         self.assertIsInstance(config.id2label, dict)
+        self.assertEqual(len(config.id2label), len(config.label2id))
         self.assertEqual(len(config.id2label), 1000)
-        self.assertEqual(config.id2label[1], "goldfish, Carassius auratus")
+        self.assertEqual(config.id2label[1], "LABEL_1")
 
         # test timm config with labels in config
         checkpoint = "timm/eva02_large_patch14_clip_336.merged2b_ft_inat21"
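Because timm checkpoints that ship no label names now get placeholder entries (`LABEL_0`, `LABEL_1`, …) instead of `label2id`/`id2label` being `None`, downstream code that wants human-readable names has to detect the placeholders itself. A minimal sketch of such a check; the helper and its heuristic are illustrative, not a transformers API:

```python
from transformers import TimmWrapperConfig

config = TimmWrapperConfig.from_pretrained("timm/resnet18.a1_in1k")

# Heuristic: placeholder names follow the "LABEL_<index>" pattern used by transformers defaults.
def has_placeholder_labels(cfg) -> bool:
    return all(name == f"LABEL_{idx}" for idx, name in cfg.id2label.items())

print(has_placeholder_labels(config))  # expected True for checkpoints without label names
```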
@@ -253,6 +260,32 @@ def test_model_init_args(self):
         self.assertEqual(len(restored_model.timm_model.blocks), 3)
 
 
+class TimmWrapperBackboneModelTester(TimmWrapperModelTester):
+    def __init__(self, parent, **kwargs):
+        super().__init__(parent, **kwargs)
+        # Add `features_only` for the backbone model
+        self.model_args = {"channels": (16, 16, 16, 16), "features_only": True}
+
+
+@require_torch
+class TimmWrapperBackboneTest(BackboneTesterMixin, unittest.TestCase):
+    all_model_classes = (TimmWrapperBackboneModel,) if is_torch_available() else ()
+    has_attentions = False
+    config_class = TimmWrapperConfig
+
+    def setUp(self):
+        self.model_tester = TimmWrapperBackboneModelTester(self)
+        self.config_tester = ConfigTester(
+            self,
+            config_class=self.config_class,
+            has_text_modality=False,
+            common_properties=["architecture"],
+        )
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+
 # We will verify our results on an image of cute cats
 def prepare_img():
     image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index cf876c0ae813..f1f8135ab452 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -43,7 +43,6 @@
     logging,
     set_seed,
 )
-from transformers.conversion_mapping import get_model_conversion_mapping
 from transformers.core_model_loading import WeightRenaming, process_target_pattern
 from transformers.integrations import HfDeepSpeedConfig
 from transformers.integrations.deepspeed import (
@@ -4673,9 +4672,10 @@ def test_reverse_loading_mapping(self, check_keys_were_modified=True):
             with self.subTest(model_class.__name__):
                 model = model_class(copy.deepcopy(config))
                 # Skip if no conversions
-                conversions = get_model_conversion_mapping(model, add_legacy=False)
+                conversions = model.get_weight_conversions_recursively(add_legacy=False)
                 if len(conversions) == 0:
-                    self.skipTest("No conversion found for this model")
+                    # No conversion mapping for this model class; keep checking the other classes
+                    continue
 
                 # Find the model keys, so the targets according to the conversions
                 model_keys = list(model.state_dict().keys())
@@ -4707,7 +4707,7 @@
                 self.assertTrue(
                     num_matches > 0,
                     f"`{source_pattern}` in `{conversion}` did not match any of the source keys. "
-                    "This indicates whether that the pattern is not properly written, ot that it could not be reversed correctly",
+                    "This indicates either that the pattern is not properly written, or that it could not be reversed correctly",
                 )
 
         # If everything is still good at this point, let's test that we perform the same operations both when
@@ -4742,9 +4742,10 @@ def test_can_load_from_already_mapped_keys(self):
                 model = model_class(copy.deepcopy(config))
 
                 # Skip if no conversions
-                conversions = get_model_conversion_mapping(model, add_legacy=False)
+                conversions = model.get_weight_conversions_recursively(add_legacy=False)
                 if len(conversions) == 0:
-                    self.skipTest("No conversion found for this model")
+                    # No conversion mapping for this model class; keep checking the other classes
+                    continue
 
                 with tempfile.TemporaryDirectory() as tmpdirname:
                     # Serialize without reverting the mapping
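The common tests above now look up the conversion mapping on the model instance itself rather than through `get_model_conversion_mapping`. A short sketch of inspecting that mapping on a loaded model, assuming only the method used in the hunk above; the tiny checkpoint name is just a convenient public test model, any model class works the same way:

```python
from transformers import AutoModelForCausalLM

# Illustrative checkpoint; replace with any model you want to inspect.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

# Collect the WeightRenaming / conversion entries registered for this model and its submodules.
# An empty list simply means checkpoint keys load without renaming for this class.
conversions = model.get_weight_conversions_recursively(add_legacy=False)
for conversion in conversions:
    print(conversion)
```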
""" config = MaskFormerConfig(backbone_config=ResNetConfig(out_indices=(0, 2))) - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config.backbone_config) self.assertEqual(backbone.out_features, ["stem", "stage2"]) self.assertEqual(backbone.out_indices, (0, 2)) self.assertIsInstance(backbone, ResNetBackbone) @@ -172,7 +179,7 @@ def test_load_backbone_from_checkpoint(self): Test that load_backbone correctly loads a backbone from a checkpoint. """ config = MaskFormerConfig(backbone="microsoft/resnet-18", backbone_config=None) - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config.backbone_config) self.assertEqual(backbone.out_indices, [4]) self.assertEqual(backbone.out_features, ["stage4"]) self.assertIsInstance(backbone, ResNetBackbone) @@ -181,7 +188,7 @@ def test_load_backbone_from_checkpoint(self): backbone="resnet18", use_timm_backbone=True, ) - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config.backbone_config) # We can't know ahead of time the exact output features and indices, or the layer names before # creating the timm model, so it defaults to the last layer (-1,) and has a different layer name self.assertEqual(backbone.out_indices, (-1,)) @@ -195,12 +202,12 @@ def test_load_backbone_backbone_kwargs(self): Test that load_backbone correctly configures the loaded backbone with the provided kwargs. """ config = MaskFormerConfig(backbone="resnet18", use_timm_backbone=True, backbone_kwargs={"out_indices": (0, 1)}) - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config.backbone_config) self.assertEqual(backbone.out_indices, (0, 1)) self.assertIsInstance(backbone, TimmBackbone) config = MaskFormerConfig(backbone="microsoft/resnet-18", backbone_kwargs={"out_indices": (0, 2)}) - backbone = load_backbone(config) + backbone = AutoBackbone.from_config(config.backbone_config) self.assertEqual(backbone.out_indices, (0, 2)) self.assertIsInstance(backbone, ResNetBackbone) @@ -223,7 +230,7 @@ def test_load_backbone_in_new_model(self): class NewModel(BertPreTrainedModel): def __init__(self, config): super().__init__(config) - self.backbone = load_backbone(config) + self.backbone = AutoBackbone.from_config(config.backbone_config) self.layer_0 = torch.nn.Linear(config.hidden_size, config.hidden_size) self.layer_1 = torch.nn.Linear(config.hidden_size, config.hidden_size)