Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/en/main_classes/backbones.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ A backbone is a model used for feature extraction for higher level computer visi
* [`~backbone_utils.BackboneMixin`] enables initializing a backbone from Transformers or [timm](https://hf.co/docs/timm/index) and includes functions for returning the output features and indices.
* [`~backbone_utils.BackboneConfigMixin`] sets the output features and indices of the backbone configuration.

[timm](https://hf.co/docs/timm/index) models are loaded with the [`TimmBackbone`] and [`TimmBackboneConfig`] classes.
[timm](https://hf.co/docs/timm/index) models are loaded with the [`TimmWrapperBackboneModel`] and [`TimmWrapperConfig`] classes.

Backbones are supported for the following models:

Expand Down
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/timm_wrapper.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] TimmWrapperImageProcessor
- preprocess

## TimmWrapperBackboneModel

[[autodoc]] TimmWrapperBackboneModel
- forward

## TimmWrapperModel

[[autodoc]] TimmWrapperModel
Expand Down
40 changes: 17 additions & 23 deletions src/transformers/backbone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,15 +297,22 @@ def consolidate_backbone_kwargs_to_config(
and backbone_config is None
and not backbone_kwargs
):
backbone_config = CONFIG_MAPPING["timm_backbone"](backbone=backbone, **timm_default_kwargs)
backbone_config = CONFIG_MAPPING["timm_wrapper"](backbone, **timm_default_kwargs)
elif backbone is not None and backbone_config is None:
if repo_exists(backbone):
config_dict, _ = PreTrainedConfig.get_config_dict(backbone)
config_class = CONFIG_MAPPING[config_dict["model_type"]]
config_dict.update(backbone_kwargs)
backbone_config = config_class(**config_dict)
else:
backbone_config = CONFIG_MAPPING["timm_backbone"](backbone=backbone, **backbone_kwargs)
# Move timm-args inside `model_args` to support loading from TimmBackboneConfig
if "model_args" not in backbone_kwargs:
backbone_kwargs["model_args"] = {
"in_chans": backbone_kwargs.pop("num_channels", 3),
"features_only": backbone_kwargs.pop("features_only", True),
"output_stride": backbone_kwargs.pop("output_stride", None),
}
backbone_config = CONFIG_MAPPING["timm_wrapper"](backbone, **backbone_kwargs)
elif backbone_config is None and default_config_type is not None:
logger.info(
f"`backbone_config` is `None`. Initializing the config with the default `{default_config_type}` vision config."
Expand All @@ -314,28 +321,15 @@ def consolidate_backbone_kwargs_to_config(
backbone_config = CONFIG_MAPPING[default_config_type](**default_config_kwargs)
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
if backbone_model_type == "timm_backbone":
backbone_model_type = "timm_wrapper"
# Move timm-args inside `model_args`
backbone_config["model_args"] = {
"in_chans": backbone_config.pop("num_channels", 3),
"features_only": backbone_config.pop("features_only", True),
"output_stride": backbone_config.pop("output_stride", None),
}
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)

return backbone_config, kwargs


def load_backbone(config):
"""
Loads the backbone model from a config object.

If the config is from the backbone model itself, then we return a backbone model with randomly initialized
weights.

If the config is from the parent model of the backbone model itself, then we load the pretrained backbone weights
if specified.
"""
from transformers import AutoBackbone

backbone_config = getattr(config, "backbone_config", None)

if backbone_config is None:
backbone = AutoBackbone.from_config(config=config)
else:
backbone = AutoBackbone.from_config(config=backbone_config)
return backbone
42 changes: 22 additions & 20 deletions src/transformers/conversion_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
"qwen3_omni_moe": "qwen2_moe",
"qwen3_omni_moe_thinker": "qwen2_moe",
"qwen3_next": "qwen2_moe",
"qwen3_5_moe": "qwen2_moe",
"hunyuan_v1_moe": "qwen2_moe",
"flex_olmo": "qwen2_moe",
"olmoe": "qwen2_moe",
Expand All @@ -67,18 +66,26 @@

def _build_checkpoint_conversion_mapping():
mapping = {
"paligemma": [
WeightRenaming(source_patterns=r"language_model.model", target_patterns="language_model"),
WeightRenaming(source_patterns=r"language_model.lm_head", target_patterns="lm_head"),
],
"llava": [
WeightRenaming(source_patterns=r"language_model.model", target_patterns="language_model"),
WeightRenaming(source_patterns=r"language_model.lm_head", target_patterns="lm_head"),
],
"qwen2_vl": [
WeightRenaming(
source_patterns=r"(^|\.)model(?!\.(language_model|visual))", target_patterns="model.language_model"
),
],
"qwen3_5_text": [
WeightRenaming(source_patterns=r"^model.language_model", target_patterns="model"),
],
"t5gemma2": [
WeightRenaming(r"(?<!vision_model\.)encoder.embed_tokens.", "encoder.text_model.embed_tokens."),
WeightRenaming(r"(?<!vision_model\.)encoder.norm.", "encoder.text_model.norm."),
WeightRenaming(r"(?<!vision_model\.)encoder.layers.", "encoder.text_model.layers."),
],
"t5gemma2_encoder": [
WeightRenaming("^embed_tokens.", "text_model.embed_tokens."),
WeightRenaming("^norm.", "text_model.norm."),
WeightRenaming("^layers.", "text_model.layers."),
WeightRenaming(r"(?<!decoder\.)(?<!text_model\.)embed_tokens.", "text_model.embed_tokens."),
WeightRenaming(r"(?<!decoder\.)(?<!text_model\.)norm.", "text_model.norm."),
WeightRenaming(r"(?<!vision_model.encoder\.)(?<!decoder\.)(?<!text_model\.)layers.", "text_model.layers."),
],
"gpt_oss": [
# NOTE: These converters are only applied if the model is being loaded from pre-dequantized checkpoint.
Expand Down Expand Up @@ -295,13 +302,13 @@ def _build_checkpoint_conversion_mapping():
operations=[MergeModulelist(dim=0)],
),
],
"timm_wrapper": [
# Simply add the prefix `timm_model`
# TODO: Would be probably much cleaner with a `add_prefix` argument in WeightRenaming
Comment on lines -298 to -300
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

base_model_prefix can do pretty well and doesn't have false-positive matches when reverse mapping

"timm_backbone": [
# For BC with backbone model after deprecating `TimmBackbone` model class
# TODO: the conversion mapping doesn't work well with literal dots (r'\.') in source
WeightRenaming(
source_patterns=r"(.+)",
target_patterns=r"timm_model.\1",
)
source_patterns=r"\._backbone\.",
target_patterns=r".timm_model.",
),
],
"legacy": [
WeightRenaming(
Expand Down Expand Up @@ -372,7 +379,6 @@ def register_checkpoint_conversion_mapping(
VLMS = [
"aria",
"ayavision",
"colpali",
"emu3",
"fuyu",
"gotocr2",
Expand All @@ -381,12 +387,9 @@ def register_checkpoint_conversion_mapping(
"llava", # all llava prefixed models fall under this check
"mistral3",
"mllama",
"paligemma",
"shieldgemma2",
"qwen2vl",
"qwen2_5_vl",
"videollava",
"vipllava",
"sam3_video",
"sam3",
"sam3_tracker",
Expand Down Expand Up @@ -422,7 +425,6 @@ def get_model_conversion_mapping(
for k, v in model._checkpoint_conversion_mapping.items()
]

# TODO: should be checked recursively on submodels!!
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needed it for timm, so we can define it once in above mapping and re-use in all models where Timm is a backbone

model_type = getattr(model.config, "model_type", None)
if model_type is not None:
model_specific_conversions = get_checkpoint_conversion_mapping(model_type)
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/core_model_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,8 @@ def process_target_pattern(pattern: str) -> tuple[str, str | None]:
# Remove negative lookahead/behind if any. This is ugly but needed for reverse mapping of
# Qwen2.5, Sam3, Ernie4.5 VL MoE!
pattern = re.sub(r"\(\?.+\)", "", pattern)
# Remove the backslash for literal dots
pattern = pattern.replace(r"\.", ".")
# Allow capturing groups in patterns, i.e. to add/remove a prefix to all keys (e.g. timm_wrapper, sam3)
capturing_group_match = re.search(r"\(.+?\)", pattern)
captured_group = None
Expand Down Expand Up @@ -1257,10 +1259,8 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch
# In this case, the model was not created with `from_pretrained` -> let's check if it's in the hardcoded
# mappings, and recreate the mapping from there if it is
if weight_conversions is None:
from .conversion_mapping import get_model_conversion_mapping

# Do not resave with the legacy renaming, if present
weight_conversions = get_model_conversion_mapping(model, add_legacy=False)
weight_conversions = model.get_weight_conversions_recursively(add_legacy=False)
weight_conversions = weight_conversions if len(weight_conversions) > 0 else None

# We did not find any operations to perform -> quick escape
Expand Down
3 changes: 1 addition & 2 deletions src/transformers/integrations/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from ..conversion_mapping import (
_MODEL_TO_CONVERSION_PATTERN,
get_checkpoint_conversion_mapping,
get_model_conversion_mapping,
)
from ..core_model_loading import (
Concatenate,
Expand Down Expand Up @@ -519,7 +518,7 @@ def load_adapter(
**load_config.download_kwargs,
)

weight_conversions = get_model_conversion_mapping(self)
weight_conversions = self.get_weight_conversions_recursively()
peft_config = convert_peft_config_for_transformers(peft_config, model=self, conversions=weight_conversions)

if hasattr(peft_config, "inference_mode"):
Expand Down
18 changes: 16 additions & 2 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4066,8 +4066,8 @@ def from_pretrained(
# instantiated model, as the flags can be modified by instances sometimes)
dtype_plan = model._get_dtype_plan(dtype)

# Obtain the weight conversion mapping for this model if any are registered
weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer)
# Obtain the weight conversion mapping for this model if any are registered, and apply it to all submodels recursively
weight_conversions = model.get_weight_conversions_recursively(key_mapping, hf_quantizer)

if _torch_distributed_available and device_mesh is not None: # add hooks to nn.Modules: no weights
model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size)
Expand Down Expand Up @@ -4265,6 +4265,20 @@ def _finalize_model_loading(

return loading_info

def get_weight_conversions_recursively(self, key_mapping=None, hf_quantizer=None, add_legacy=True):
conversions = []
conversions.extend(get_model_conversion_mapping(self, key_mapping, hf_quantizer, add_legacy))

for submodule in self.children():
if (
submodule is not self
and isinstance(submodule, PreTrainedModel)
and submodule.config.__class__ != self.config.__class__
):
conversions.extend(get_model_conversion_mapping(submodule, key_mapping, hf_quantizer, add_legacy))
conversions.extend(submodule.get_weight_conversions_recursively(key_mapping, hf_quantizer, add_legacy))
return conversions

def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
module_keys = {".".join(key.split(".")[:-1]) for key in names}

Expand Down
71 changes: 30 additions & 41 deletions src/transformers/models/auto/auto_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
from collections.abc import Iterator
from typing import Any, TypeVar

from huggingface_hub import repo_exists

from ...configuration_utils import PreTrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...utils import (
Expand Down Expand Up @@ -247,8 +245,38 @@ def _prepare_config_for_auto_class(cls, config: PreTrainedConfig) -> PreTrainedC
"""Additional autoclass-specific config post-loading manipulation. May be overridden in subclasses."""
return config

@classmethod
def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
requires_backends(cls, ["vision", "timm"])
from ...models.timm_wrapper import TimmWrapperConfig

if kwargs.get("output_loading_info", False):
raise ValueError("Cannot specify `output_loading_info=True` when loading from timm")

# Users can't pass `config` and `kwargs`, choose only one!
config = kwargs.pop("config", None)
if config is None:
config = TimmWrapperConfig(
architecture=pretrained_model_name_or_path,
do_pooling=False,
out_indices=kwargs.pop("out_indices", (-1,)),
model_args={
"in_chans": kwargs.pop("num_channels", 3),
"features_only": kwargs.pop("features_only", True),
},
)

# Always load a pretrained model when `from_pretrained` is called
kwargs.pop("use_pretrained_backbone", None)
return cls.from_config(config, pretrained=True, **kwargs)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], *model_args, **kwargs):
# Early exit for `timm` models, which usually aren't hosted on the Hub
use_timm_backbone = kwargs.pop("use_timm_backbone", None)
Comment on lines +275 to +276
Copy link
Copy Markdown
Member Author

@zucchini-nlp zucchini-nlp Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's keep it actually and use only when Autobackbone.from_pretrained(). We don't call from_pretrained anywhere across repo so it will be used only by users
Then we can delete _BaseAutoBackboneClass

if use_timm_backbone:
return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

config = kwargs.pop("config", None)
trust_remote_code = kwargs.get("trust_remote_code")
kwargs["_from_auto"] = True
Expand Down Expand Up @@ -399,45 +427,6 @@ def register(cls, config_class, model_class, exist_ok=False) -> None:
cls._model_mapping.register(config_class, model_class, exist_ok=exist_ok)


class _BaseAutoBackboneClass(_BaseAutoModelClass):
# Base class for auto backbone models.
_model_mapping = None

@classmethod
def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
requires_backends(cls, ["vision", "timm"])
from ...models.timm_backbone import TimmBackboneConfig

config = kwargs.pop("config", TimmBackboneConfig())

if kwargs.get("out_features") is not None:
raise ValueError("Cannot specify `out_features` for timm backbones")

if kwargs.get("output_loading_info", False):
raise ValueError("Cannot specify `output_loading_info=True` when loading from timm")

num_channels = kwargs.pop("num_channels", config.num_channels)
features_only = kwargs.pop("features_only", config.features_only)
out_indices = kwargs.pop("out_indices", config.out_indices)
config = TimmBackboneConfig(
backbone=pretrained_model_name_or_path,
num_channels=num_channels,
features_only=features_only,
out_indices=out_indices,
)
# Always load a pretrained model when `from_pretrained` is called
kwargs.pop("use_pretrained_backbone", None)
return super().from_config(config, pretrained=True, **kwargs)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
kwargs.pop("use_timm_backbone", None)
if not repo_exists(pretrained_model_name_or_path):
return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)


def insert_head_doc(docstring, head_doc: str = ""):
if len(head_doc) > 0:
return docstring.replace(
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@
("time_series_transformer", "TimeSeriesTransformerConfig"),
("timesfm", "TimesFmConfig"),
("timesformer", "TimesformerConfig"),
("timm_backbone", "TimmBackboneConfig"),
("timm_backbone", "TimmBackboneConfig"), # for BC
Copy link
Copy Markdown
Member Author

@zucchini-nlp zucchini-nlp Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should map any timm_backbone to the new model class when loading, so we don't log deprecation warnings. Mapping happens in the auto-modeling file

("timm_wrapper", "TimmWrapperConfig"),
("trocr", "TrOCRConfig"),
("tvp", "TvpConfig"),
Expand Down
8 changes: 3 additions & 5 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from ...utils import logging
from .auto_factory import (
_BaseAutoBackboneClass,
_BaseAutoModelClass,
_LazyAutoMapping,
auto_class_update,
Expand Down Expand Up @@ -427,7 +426,6 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("time_series_transformer", "TimeSeriesTransformerModel"),
("timesfm", "TimesFmModel"),
("timesformer", "TimesformerModel"),
("timm_backbone", "TimmBackbone"),
("timm_wrapper", "TimmWrapperModel"),
("tvp", "TvpModel"),
("udop", "UdopModel"),
Expand Down Expand Up @@ -778,7 +776,6 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("swinv2", "Swinv2Model"),
("table-transformer", "TableTransformerModel"),
("timesformer", "TimesformerModel"),
("timm_backbone", "TimmBackbone"),
("timm_wrapper", "TimmWrapperModel"),
("videomae", "VideoMAEModel"),
("vit", "ViTModel"),
Expand Down Expand Up @@ -1647,7 +1644,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("swin", "SwinBackbone"),
("swinv2", "Swinv2Backbone"),
("textnet", "TextNetBackbone"),
("timm_backbone", "TimmBackbone"),
("timm_backbone", "TimmWrapperBackboneModel"), # for BC
("timm_wrapper", "TimmWrapperBackboneModel"),
("vitdet", "VitDetBackbone"),
("vitpose_backbone", "VitPoseBackbone"),
]
Expand Down Expand Up @@ -2161,7 +2159,7 @@ class AutoModelForTextToWaveform(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING


class AutoBackbone(_BaseAutoBackboneClass):
class AutoBackbone(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_BACKBONE_MAPPING


Expand Down
Loading