-
Notifications
You must be signed in to change notification settings - Fork 32.9k
Timm unification continued #44252
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Timm unification continued #44252
Changes from all commits
5579b0c
a5dbd5f
5c7eb10
4303988
217070e
55e9723
f7a0fec
abf4993
8a6e379
04e0d8a
ba56a16
a829360
100c489
854f2fb
c6476bc
3a790ec
ab6def3
375eb3e
1d94e42
2f36718
882b555
9794fcf
f13720c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -55,7 +55,6 @@ | |
| "qwen3_omni_moe": "qwen2_moe", | ||
| "qwen3_omni_moe_thinker": "qwen2_moe", | ||
| "qwen3_next": "qwen2_moe", | ||
| "qwen3_5_moe": "qwen2_moe", | ||
| "hunyuan_v1_moe": "qwen2_moe", | ||
| "flex_olmo": "qwen2_moe", | ||
| "olmoe": "qwen2_moe", | ||
|
|
@@ -67,18 +66,26 @@ | |
|
|
||
| def _build_checkpoint_conversion_mapping(): | ||
| mapping = { | ||
| "paligemma": [ | ||
| WeightRenaming(source_patterns=r"language_model.model", target_patterns="language_model"), | ||
| WeightRenaming(source_patterns=r"language_model.lm_head", target_patterns="lm_head"), | ||
| ], | ||
| "llava": [ | ||
| WeightRenaming(source_patterns=r"language_model.model", target_patterns="language_model"), | ||
| WeightRenaming(source_patterns=r"language_model.lm_head", target_patterns="lm_head"), | ||
| ], | ||
| "qwen2_vl": [ | ||
| WeightRenaming( | ||
| source_patterns=r"(^|\.)model(?!\.(language_model|visual))", target_patterns="model.language_model" | ||
| ), | ||
| ], | ||
| "qwen3_5_text": [ | ||
| WeightRenaming(source_patterns=r"^model.language_model", target_patterns="model"), | ||
| ], | ||
| "t5gemma2": [ | ||
| WeightRenaming(r"(?<!vision_model\.)encoder.embed_tokens.", "encoder.text_model.embed_tokens."), | ||
| WeightRenaming(r"(?<!vision_model\.)encoder.norm.", "encoder.text_model.norm."), | ||
| WeightRenaming(r"(?<!vision_model\.)encoder.layers.", "encoder.text_model.layers."), | ||
| ], | ||
| "t5gemma2_encoder": [ | ||
| WeightRenaming("^embed_tokens.", "text_model.embed_tokens."), | ||
| WeightRenaming("^norm.", "text_model.norm."), | ||
| WeightRenaming("^layers.", "text_model.layers."), | ||
| WeightRenaming(r"(?<!decoder\.)(?<!text_model\.)embed_tokens.", "text_model.embed_tokens."), | ||
| WeightRenaming(r"(?<!decoder\.)(?<!text_model\.)norm.", "text_model.norm."), | ||
| WeightRenaming(r"(?<!vision_model.encoder\.)(?<!decoder\.)(?<!text_model\.)layers.", "text_model.layers."), | ||
| ], | ||
| "gpt_oss": [ | ||
| # NOTE: These converters are only applied if the model is being loaded from pre-dequantized checkpoint. | ||
|
|
@@ -295,13 +302,13 @@ def _build_checkpoint_conversion_mapping(): | |
| operations=[MergeModulelist(dim=0)], | ||
| ), | ||
| ], | ||
| "timm_wrapper": [ | ||
| # Simply add the prefix `timm_model` | ||
| # TODO: Would be probably much cleaner with a `add_prefix` argument in WeightRenaming | ||
| "timm_backbone": [ | ||
| # For BC with backbone model after deprecating `TimmBackbone` model class | ||
| # TODO: the conversion mapping doesn't work well with literal dots (r'\.') in source | ||
| WeightRenaming( | ||
| source_patterns=r"(.+)", | ||
| target_patterns=r"timm_model.\1", | ||
| ) | ||
| source_patterns=r"\._backbone\.", | ||
| target_patterns=r".timm_model.", | ||
| ), | ||
| ], | ||
| "legacy": [ | ||
| WeightRenaming( | ||
|
|
@@ -372,7 +379,6 @@ def register_checkpoint_conversion_mapping( | |
| VLMS = [ | ||
| "aria", | ||
| "ayavision", | ||
| "colpali", | ||
| "emu3", | ||
| "fuyu", | ||
| "gotocr2", | ||
|
|
@@ -381,12 +387,9 @@ def register_checkpoint_conversion_mapping( | |
| "llava", # all llava prefixed models fall under this check | ||
| "mistral3", | ||
| "mllama", | ||
| "paligemma", | ||
| "shieldgemma2", | ||
| "qwen2vl", | ||
| "qwen2_5_vl", | ||
| "videollava", | ||
| "vipllava", | ||
| "sam3_video", | ||
| "sam3", | ||
| "sam3_tracker", | ||
|
|
@@ -422,7 +425,6 @@ def get_model_conversion_mapping( | |
| for k, v in model._checkpoint_conversion_mapping.items() | ||
| ] | ||
|
|
||
| # TODO: should be checked recursively on submodels!! | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. needed it for timm, so we can define it once in above mapping and re-use in all models where Timm is a backbone |
||
| model_type = getattr(model.config, "model_type", None) | ||
| if model_type is not None: | ||
| model_specific_conversions = get_checkpoint_conversion_mapping(model_type) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,8 +21,6 @@ | |
| from collections.abc import Iterator | ||
| from typing import Any, TypeVar | ||
|
|
||
| from huggingface_hub import repo_exists | ||
|
|
||
| from ...configuration_utils import PreTrainedConfig | ||
| from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code | ||
| from ...utils import ( | ||
|
|
@@ -247,8 +245,38 @@ def _prepare_config_for_auto_class(cls, config: PreTrainedConfig) -> PreTrainedC | |
| """Additional autoclass-specific config post-loading manipulation. May be overridden in subclasses.""" | ||
| return config | ||
|
|
||
| @classmethod | ||
| def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): | ||
| requires_backends(cls, ["vision", "timm"]) | ||
| from ...models.timm_wrapper import TimmWrapperConfig | ||
|
|
||
| if kwargs.get("output_loading_info", False): | ||
| raise ValueError("Cannot specify `output_loading_info=True` when loading from timm") | ||
|
|
||
| # Users can't pass `config` and `kwargs`, choose only one! | ||
| config = kwargs.pop("config", None) | ||
| if config is None: | ||
| config = TimmWrapperConfig( | ||
| architecture=pretrained_model_name_or_path, | ||
| do_pooling=False, | ||
| out_indices=kwargs.pop("out_indices", (-1,)), | ||
| model_args={ | ||
| "in_chans": kwargs.pop("num_channels", 3), | ||
| "features_only": kwargs.pop("features_only", True), | ||
| }, | ||
| ) | ||
|
|
||
| # Always load a pretrained model when `from_pretrained` is called | ||
| kwargs.pop("use_pretrained_backbone", None) | ||
| return cls.from_config(config, pretrained=True, **kwargs) | ||
|
|
||
| @classmethod | ||
| def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], *model_args, **kwargs): | ||
| # Early exit for `timm` models, they aren't hosted on the hub usually | ||
| use_timm_backbone = kwargs.pop("use_timm_backbone", None) | ||
|
Comment on lines
+275
to
+276
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's keep it actually and use only when |
||
| if use_timm_backbone: | ||
| return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) | ||
|
|
||
| config = kwargs.pop("config", None) | ||
| trust_remote_code = kwargs.get("trust_remote_code") | ||
| kwargs["_from_auto"] = True | ||
|
|
@@ -399,45 +427,6 @@ def register(cls, config_class, model_class, exist_ok=False) -> None: | |
| cls._model_mapping.register(config_class, model_class, exist_ok=exist_ok) | ||
|
|
||
|
|
||
| class _BaseAutoBackboneClass(_BaseAutoModelClass): | ||
| # Base class for auto backbone models. | ||
| _model_mapping = None | ||
|
|
||
| @classmethod | ||
| def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): | ||
| requires_backends(cls, ["vision", "timm"]) | ||
| from ...models.timm_backbone import TimmBackboneConfig | ||
|
|
||
| config = kwargs.pop("config", TimmBackboneConfig()) | ||
|
|
||
| if kwargs.get("out_features") is not None: | ||
| raise ValueError("Cannot specify `out_features` for timm backbones") | ||
|
|
||
| if kwargs.get("output_loading_info", False): | ||
| raise ValueError("Cannot specify `output_loading_info=True` when loading from timm") | ||
|
|
||
| num_channels = kwargs.pop("num_channels", config.num_channels) | ||
| features_only = kwargs.pop("features_only", config.features_only) | ||
| out_indices = kwargs.pop("out_indices", config.out_indices) | ||
| config = TimmBackboneConfig( | ||
| backbone=pretrained_model_name_or_path, | ||
| num_channels=num_channels, | ||
| features_only=features_only, | ||
| out_indices=out_indices, | ||
| ) | ||
| # Always load a pretrained model when `from_pretrained` is called | ||
| kwargs.pop("use_pretrained_backbone", None) | ||
| return super().from_config(config, pretrained=True, **kwargs) | ||
|
|
||
| @classmethod | ||
| def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): | ||
| kwargs.pop("use_timm_backbone", None) | ||
| if not repo_exists(pretrained_model_name_or_path): | ||
| return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) | ||
|
|
||
| return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) | ||
|
|
||
|
|
||
| def insert_head_doc(docstring, head_doc: str = ""): | ||
| if len(head_doc) > 0: | ||
| return docstring.replace( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -445,7 +445,7 @@ | |
| ("time_series_transformer", "TimeSeriesTransformerConfig"), | ||
| ("timesfm", "TimesFmConfig"), | ||
| ("timesformer", "TimesformerConfig"), | ||
| ("timm_backbone", "TimmBackboneConfig"), | ||
| ("timm_backbone", "TimmBackboneConfig"), # for BC | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should map any |
||
| ("timm_wrapper", "TimmWrapperConfig"), | ||
| ("trocr", "TrOCRConfig"), | ||
| ("tvp", "TvpConfig"), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
base_model_prefixcan do pretty well and doesn't have false-positive matches when reverse mapping