diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 0d2b752d5ad9..56fd2c88ecf4 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -703,6 +703,8 @@ title: Swin2SR - local: model_doc/table-transformer title: Table Transformer + - local: model_doc/timm_wrapper + title: Timm Wrapper - local: model_doc/upernet title: UperNet - local: model_doc/van diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 8a9ccf45b69c..23c73ee9c20a 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -317,6 +317,7 @@ Flax), PyTorch, and/or TensorFlow. | [TAPEX](model_doc/tapex) | ✅ | ✅ | ✅ | | [Time Series Transformer](model_doc/time_series_transformer) | ✅ | ❌ | ❌ | | [TimeSformer](model_doc/timesformer) | ✅ | ❌ | ❌ | +| [TimmWrapperModel](model_doc/timm_wrapper) | ✅ | ❌ | ❌ | | [Trajectory Transformer](model_doc/trajectory_transformer) | ✅ | ❌ | ❌ | | [Transformer-XL](model_doc/transfo-xl) | ✅ | ✅ | ❌ | | [TrOCR](model_doc/trocr) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/timm_wrapper.md b/docs/source/en/model_doc/timm_wrapper.md new file mode 100644 index 000000000000..5af3d51746c3 --- /dev/null +++ b/docs/source/en/model_doc/timm_wrapper.md @@ -0,0 +1,67 @@ + + +# TimmWrapper + +## Overview + +Helper class to enable loading timm models to be used with the transformers library and its autoclasses. + +```python +>>> import torch +>>> from PIL import Image +>>> from urllib.request import urlopen +>>> from transformers import AutoModelForImageClassification, AutoImageProcessor + +>>> # Load image +>>> image = Image.open(urlopen( +... 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png' +... )) + +>>> # Load model and image processor +>>> checkpoint = "timm/resnet50.a1_in1k" +>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint) +>>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval() + +>>> # Preprocess image +>>> inputs = image_processor(image) + +>>> # Forward pass +>>> with torch.no_grad(): +... logits = model(**inputs).logits + +>>> # Get top 5 predictions +>>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5) +``` + +## TimmWrapperConfig + +[[autodoc]] TimmWrapperConfig + +## TimmWrapperImageProcessor + +[[autodoc]] TimmWrapperImageProcessor + - preprocess + +## TimmWrapperModel + +[[autodoc]] TimmWrapperModel + - forward + +## TimmWrapperForImageClassification + +[[autodoc]] TimmWrapperForImageClassification + - forward diff --git a/examples/pytorch/image-classification/run_image_classification.py b/examples/pytorch/image-classification/run_image_classification.py index aa1cd089ef5c..d7b26ee58cc3 100755 --- a/examples/pytorch/image-classification/run_image_classification.py +++ b/examples/pytorch/image-classification/run_image_classification.py @@ -42,6 +42,7 @@ AutoImageProcessor, AutoModelForImageClassification, HfArgumentParser, + TimmWrapperImageProcessor, Trainer, TrainingArguments, set_seed, @@ -329,31 +330,36 @@ def compute_metrics(p): ) # Define torchvision transforms to be applied to each image. 
- if "shortest_edge" in image_processor.size: - size = image_processor.size["shortest_edge"] + if isinstance(image_processor, TimmWrapperImageProcessor): + _train_transforms = image_processor.train_transforms + _val_transforms = image_processor.val_transforms else: - size = (image_processor.size["height"], image_processor.size["width"]) - normalize = ( - Normalize(mean=image_processor.image_mean, std=image_processor.image_std) - if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std") - else Lambda(lambda x: x) - ) - _train_transforms = Compose( - [ - RandomResizedCrop(size), - RandomHorizontalFlip(), - ToTensor(), - normalize, - ] - ) - _val_transforms = Compose( - [ - Resize(size), - CenterCrop(size), - ToTensor(), - normalize, - ] - ) + if "shortest_edge" in image_processor.size: + size = image_processor.size["shortest_edge"] + else: + size = (image_processor.size["height"], image_processor.size["width"]) + + # Create normalization transform + if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std"): + normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std) + else: + normalize = Lambda(lambda x: x) + _train_transforms = Compose( + [ + RandomResizedCrop(size), + RandomHorizontalFlip(), + ToTensor(), + normalize, + ] + ) + _val_transforms = Compose( + [ + Resize(size), + CenterCrop(size), + ToTensor(), + normalize, + ] + ) def train_transforms(example_batch): """Apply _train_transforms across a batch.""" diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e1ca19568073..03637b604e9d 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -776,6 +776,7 @@ "models.time_series_transformer": ["TimeSeriesTransformerConfig"], "models.timesformer": ["TimesformerConfig"], "models.timm_backbone": ["TimmBackboneConfig"], + "models.timm_wrapper": ["TimmWrapperConfig"], "models.trocr": [ "TrOCRConfig", "TrOCRProcessor", @@ -1265,6 +1266,18 @@ _import_structure["models.rt_detr"].append("RTDetrImageProcessorFast") _import_structure["models.vit"].append("ViTImageProcessorFast") +try: + if not is_torchvision_available() and not is_timm_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_timm_and_torchvision_objects + + _import_structure["utils.dummy_timm_and_torchvision_objects"] = [ + name for name in dir(dummy_timm_and_torchvision_objects) if not name.startswith("_") + ] +else: + _import_structure["models.timm_wrapper"].extend(["TimmWrapperImageProcessor"]) + # PyTorch-backed objects try: if not is_torch_available(): @@ -3507,6 +3520,9 @@ ] ) _import_structure["models.timm_backbone"].extend(["TimmBackbone"]) + _import_structure["models.timm_wrapper"].extend( + ["TimmWrapperForImageClassification", "TimmWrapperModel", "TimmWrapperPreTrainedModel"] + ) _import_structure["models.trocr"].extend( [ "TrOCRForCausalLM", @@ -5703,6 +5719,7 @@ TimesformerConfig, ) from .models.timm_backbone import TimmBackboneConfig + from .models.timm_wrapper import TimmWrapperConfig from .models.trocr import ( TrOCRConfig, TrOCRProcessor, @@ -6195,6 +6212,14 @@ from .models.rt_detr import RTDetrImageProcessorFast from .models.vit import ViTImageProcessorFast + try: + if not is_torchvision_available() and not is_timm_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_timm_and_torchvision_objects import * + else: + from .models.timm_wrapper import 
TimmWrapperImageProcessor + # Modeling try: if not is_torch_available(): @@ -7991,6 +8016,11 @@ TimesformerPreTrainedModel, ) from .models.timm_backbone import TimmBackbone + from .models.timm_wrapper import ( + TimmWrapperForImageClassification, + TimmWrapperModel, + TimmWrapperPreTrainedModel, + ) from .models.trocr import ( TrOCRForCausalLM, TrOCRPreTrainedModel, diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index e49eab86b4e1..a04b7bd6aa1b 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -37,6 +37,7 @@ download_url, extract_commit_hash, is_remote_url, + is_timm_config_dict, is_torch_available, logging, ) @@ -702,6 +703,11 @@ def _get_config_dict( config_dict["custom_pipelines"] = add_model_info_to_custom_pipelines( config_dict["custom_pipelines"], pretrained_model_name_or_path ) + + # timm models are not saved with the model_type in the config file + if "model_type" not in config_dict and is_timm_config_dict(config_dict): + config_dict["model_type"] = "timm_wrapper" + return config_dict, kwargs @classmethod diff --git a/src/transformers/image_processing_base.py b/src/transformers/image_processing_base.py index e73d4a8a56f3..c5af652decf2 100644 --- a/src/transformers/image_processing_base.py +++ b/src/transformers/image_processing_base.py @@ -285,6 +285,8 @@ def get_image_processor_dict( subfolder (`str`, *optional*, defaults to `""`): In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here. + image_processor_filename (`str`, *optional*, defaults to `"preprocessor_config.json"`): + The name of the file in the model directory to use for the image processor config. Returns: `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object. @@ -298,6 +300,7 @@ def get_image_processor_dict( local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", "") + image_processor_filename = kwargs.pop("image_processor_filename", IMAGE_PROCESSOR_NAME) from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) @@ -324,7 +327,7 @@ def get_image_processor_dict( pretrained_model_name_or_path = str(pretrained_model_name_or_path) is_local = os.path.isdir(pretrained_model_name_or_path) if os.path.isdir(pretrained_model_name_or_path): - image_processor_file = os.path.join(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME) + image_processor_file = os.path.join(pretrained_model_name_or_path, image_processor_filename) if os.path.isfile(pretrained_model_name_or_path): resolved_image_processor_file = pretrained_model_name_or_path is_local = True @@ -332,7 +335,7 @@ def get_image_processor_dict( image_processor_file = pretrained_model_name_or_path resolved_image_processor_file = download_url(pretrained_model_name_or_path) else: - image_processor_file = IMAGE_PROCESSOR_NAME + image_processor_file = image_processor_filename try: # Load from local folder or from cache or download from model Hub and cache resolved_image_processor_file = cached_file( @@ -358,7 +361,7 @@ def get_image_processor_dict( f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load" " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" f" same name. 
Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" - f" directory containing a {IMAGE_PROCESSOR_NAME} file" + f" directory containing a {image_processor_filename} file" ) try: diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index dae29111c8dc..f349847b1fd7 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -503,7 +503,7 @@ def load_state_dict( # Check format of the archive with safe_open(checkpoint_file, framework="pt") as f: metadata = f.metadata() - if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]: + if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]: raise OSError( f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " "you save your model with the `save_pretrained` method." @@ -652,36 +652,6 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor] def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False): - # Convert old format to new format if needed from a PyTorch state_dict - old_keys = [] - new_keys = [] - renamed_keys = {} - renamed_gamma = {} - renamed_beta = {} - warning_msg = f"A pretrained model of type `{model_to_load.__class__.__name__}` " - for key in state_dict.keys(): - new_key = None - if "gamma" in key: - # We add only the first key as an example - new_key = key.replace("gamma", "weight") - renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma - if "beta" in key: - # We add only the first key as an example - new_key = key.replace("beta", "bias") - renamed_beta[key] = new_key if not renamed_beta else renamed_beta - if new_key: - old_keys.append(key) - new_keys.append(new_key) - renamed_keys = {**renamed_gamma, **renamed_beta} - if renamed_keys: - warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n" - for old_key, new_key in renamed_keys.items(): - warning_msg += f"* `{old_key}` -> `{new_key}`\n" - warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users." - logger.info_once(warning_msg) - for old_key, new_key in zip(old_keys, new_keys): - state_dict[new_key] = state_dict.pop(old_key) - # copy state_dict so _load_from_state_dict can modify it metadata = getattr(state_dict, "_metadata", None) state_dict = state_dict.copy() @@ -812,46 +782,7 @@ def _load_state_dict_into_meta_model( error_msgs = [] - old_keys = [] - new_keys = [] - renamed_gamma = {} - renamed_beta = {} is_quantized = hf_quantizer is not None - warning_msg = f"This model {type(model)}" - for key in state_dict.keys(): - new_key = None - if "gamma" in key: - # We add only the first key as an example - new_key = key.replace("gamma", "weight") - renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma - if "beta" in key: - # We add only the first key as an example - new_key = key.replace("beta", "bias") - renamed_beta[key] = new_key if not renamed_beta else renamed_beta - - # To reproduce `_load_state_dict_into_model` behaviour, we need to manually rename parametrized weigth norm, if necessary. 
- if hasattr(nn.utils.parametrizations, "weight_norm"): - if "weight_g" in key: - new_key = key.replace("weight_g", "parametrizations.weight.original0") - if "weight_v" in key: - new_key = key.replace("weight_v", "parametrizations.weight.original1") - else: - if "parametrizations.weight.original0" in key: - new_key = key.replace("parametrizations.weight.original0", "weight_g") - if "parametrizations.weight.original1" in key: - new_key = key.replace("parametrizations.weight.original1", "weight_v") - if new_key: - old_keys.append(key) - new_keys.append(new_key) - renamed_keys = {**renamed_gamma, **renamed_beta} - if renamed_keys: - warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n" - for old_key, new_key in renamed_keys.items(): - warning_msg += f"* `{old_key}` -> `{new_key}`\n" - warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users." - logger.info_once(warning_msg) - for old_key, new_key in zip(old_keys, new_keys): - state_dict[new_key] = state_dict.pop(old_key) is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") @@ -2888,6 +2819,11 @@ def save_pretrained( for ignore_key in self._keys_to_ignore_on_save: if ignore_key in state_dict.keys(): del state_dict[ignore_key] + + # Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model. + # (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm) + state_dict = self._fix_state_dict_keys_on_save(state_dict) + if safe_serialization: # Safetensors does not allow tensor aliasing. # We're going to remove aliases before saving @@ -4010,7 +3946,10 @@ def from_pretrained( with safe_open(resolved_archive_file, framework="pt") as f: metadata = f.metadata() - if metadata.get("format") == "pt": + if metadata is None: + # Assume it's a pytorch checkpoint (introduced for timm checkpoints) + pass + elif metadata.get("format") == "pt": pass elif metadata.get("format") == "tf": from_tf = True @@ -4375,6 +4314,72 @@ def from_pretrained( return model + @staticmethod + def _fix_state_dict_key_on_load(key): + """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight.""" + + if "beta" in key: + return key.replace("beta", "bias") + if "gamma" in key: + return key.replace("gamma", "weight") + + # to avoid logging parametrized weight norm renaming + if hasattr(nn.utils.parametrizations, "weight_norm"): + if "weight_g" in key: + return key.replace("weight_g", "parametrizations.weight.original0") + if "weight_v" in key: + return key.replace("weight_v", "parametrizations.weight.original1") + else: + if "parametrizations.weight.original0" in key: + return key.replace("parametrizations.weight.original0", "weight_g") + if "parametrizations.weight.original1" in key: + return key.replace("parametrizations.weight.original1", "weight_v") + return key + + @classmethod + def _fix_state_dict_keys_on_load(cls, state_dict): + """Fixes state dict keys by replacing legacy parameter names with their modern equivalents. + Logs if any parameters have been renamed. 
+ """ + + renamed_keys = {} + state_dict_keys = list(state_dict.keys()) + for key in state_dict_keys: + new_key = cls._fix_state_dict_key_on_load(key) + if new_key != key: + state_dict[new_key] = state_dict.pop(key) + + # add it once for logging + if "gamma" in key and "gamma" not in renamed_keys: + renamed_keys["gamma"] = (key, new_key) + if "beta" in key and "beta" not in renamed_keys: + renamed_keys["beta"] = (key, new_key) + + if renamed_keys: + warning_msg = f"A pretrained model of type `{cls.__name__}` " + warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n" + for old_key, new_key in renamed_keys.values(): + warning_msg += f"* `{old_key}` -> `{new_key}`\n" + warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users." + logger.info_once(warning_msg) + + return state_dict + + @staticmethod + def _fix_state_dict_key_on_save(key): + """ + Similar to `_fix_state_dict_key_on_load`, allows defining a hook for state dict key renaming on model save. + Does nothing by default, but can be overridden in particular models. + """ + return key + + def _fix_state_dict_keys_on_save(self, state_dict): + """ + Similar to `_fix_state_dict_keys_on_load`, allows defining a hook for state dict key renaming on model save. + Applies `_fix_state_dict_key_on_save` to all keys in `state_dict`. + """ + return {self._fix_state_dict_key_on_save(key): value for key, value in state_dict.items()} + @classmethod def _load_pretrained_model( cls, @@ -4430,27 +4435,8 @@ def _load_pretrained_model( if hf_quantizer is not None: expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys) - def _fix_key(key): - if "beta" in key: - return key.replace("beta", "bias") - if "gamma" in key: - return key.replace("gamma", "weight") - - # to avoid logging parametrized weight norm renaming - if hasattr(nn.utils.parametrizations, "weight_norm"): - if "weight_g" in key: - return key.replace("weight_g", "parametrizations.weight.original0") - if "weight_v" in key: - return key.replace("weight_v", "parametrizations.weight.original1") - else: - if "parametrizations.weight.original0" in key: - return key.replace("parametrizations.weight.original0", "weight_g") - if "parametrizations.weight.original1" in key: - return key.replace("parametrizations.weight.original1", "weight_v") - return key - original_loaded_keys = loaded_keys - loaded_keys = [_fix_key(key) for key in loaded_keys] + loaded_keys = [cls._fix_state_dict_key_on_load(key) for key in loaded_keys] if len(prefix) > 0: has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) @@ -4615,23 +4601,23 @@ def _find_mismatched_keys( state_dict, model_state_dict, loaded_keys, + original_loaded_keys, add_prefix_to_model, remove_prefix_from_model, ignore_mismatched_sizes, ): mismatched_keys = [] if ignore_mismatched_sizes: - for checkpoint_key in loaded_keys: + for checkpoint_key, model_key in zip(original_loaded_keys, loaded_keys): # If the checkpoint is sharded, we may not have the key here. if checkpoint_key not in state_dict: continue - model_key = checkpoint_key if remove_prefix_from_model: # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it. - model_key = f"{prefix}.{checkpoint_key}" + model_key = f"{prefix}.{model_key}" elif add_prefix_to_model: # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it. 
- model_key = ".".join(checkpoint_key.split(".")[1:]) + model_key = ".".join(model_key.split(".")[1:]) if ( model_key in model_state_dict @@ -4680,6 +4666,7 @@ def _find_mismatched_keys( mismatched_keys = _find_mismatched_keys( state_dict, model_state_dict, + loaded_keys, original_loaded_keys, add_prefix_to_model, remove_prefix_from_model, @@ -4688,9 +4675,10 @@ def _find_mismatched_keys( # For GGUF models `state_dict` is never set to None as the state dict is always small if gguf_path: + fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict) error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( model_to_load, - state_dict, + fixed_state_dict, start_prefix, expected_keys, device_map=device_map, @@ -4709,8 +4697,9 @@ def _find_mismatched_keys( assign_to_params_buffers = check_support_param_buffer_assignment( model_to_load, state_dict, start_prefix ) + fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict) error_msgs = _load_state_dict_into_model( - model_to_load, state_dict, start_prefix, assign_to_params_buffers + model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers ) else: @@ -4761,6 +4750,7 @@ def _find_mismatched_keys( mismatched_keys += _find_mismatched_keys( state_dict, model_state_dict, + loaded_keys, original_loaded_keys, add_prefix_to_model, remove_prefix_from_model, @@ -4774,9 +4764,10 @@ def _find_mismatched_keys( model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype) ) else: + fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict) new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( model_to_load, - state_dict, + fixed_state_dict, start_prefix, expected_keys, device_map=device_map, @@ -4797,8 +4788,9 @@ def _find_mismatched_keys( assign_to_params_buffers = check_support_param_buffer_assignment( model_to_load, state_dict, start_prefix ) + fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict) error_msgs += _load_state_dict_into_model( - model_to_load, state_dict, start_prefix, assign_to_params_buffers + model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers ) # force memory release @@ -4930,9 +4922,10 @@ def _load_pretrained_model_low_mem( _move_model_to_meta(model, loaded_state_dict_keys, start_prefix) state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only) expected_keys = loaded_state_dict_keys # plug for missing expected_keys. 
TODO: replace with proper keys + fixed_state_dict = model._fix_state_dict_keys_on_load(state_dict) error_msgs = _load_state_dict_into_meta_model( model, - state_dict, + fixed_state_dict, start_prefix, expected_keys=expected_keys, hf_quantizer=hf_quantizer, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 2d2a3b41d437..b634bc5e7635 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -247,6 +247,7 @@ time_series_transformer, timesformer, timm_backbone, + timm_wrapper, trocr, tvp, udop, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 4ab6d3922826..050ad8fd7a7b 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -272,6 +272,7 @@ ("time_series_transformer", "TimeSeriesTransformerConfig"), ("timesformer", "TimesformerConfig"), ("timm_backbone", "TimmBackboneConfig"), + ("timm_wrapper", "TimmWrapperConfig"), ("trajectory_transformer", "TrajectoryTransformerConfig"), ("transfo-xl", "TransfoXLConfig"), ("trocr", "TrOCRConfig"), @@ -591,6 +592,7 @@ ("time_series_transformer", "Time Series Transformer"), ("timesformer", "TimeSformer"), ("timm_backbone", "TimmBackbone"), + ("timm_wrapper", "TimmWrapperModel"), ("trajectory_transformer", "Trajectory Transformer"), ("transfo-xl", "Transformer-XL"), ("trocr", "TrOCR"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 11ae15ca461e..d2f68f4b0ce3 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -30,6 +30,8 @@ CONFIG_NAME, IMAGE_PROCESSOR_NAME, get_file_from_repo, + is_timm_config_dict, + is_timm_local_checkpoint, is_torchvision_available, is_vision_available, logging, @@ -135,6 +137,7 @@ ("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")), ("table-transformer", ("DetrImageProcessor",)), ("timesformer", ("VideoMAEImageProcessor",)), + ("timm_wrapper", ("TimmWrapperImageProcessor",)), ("tvlt", ("TvltImageProcessor",)), ("tvp", ("TvpImageProcessor",)), ("udop", ("LayoutLMv3ImageProcessor",)), @@ -374,6 +377,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine. + image_processor_filename (`str`, *optional*, defaults to `"config.json"`): + The name of the file in the model directory to use for the image processor config. kwargs (`Dict[str, Any]`, *optional*): The values in kwargs of any keys which are image processor attributes will be used to override the loaded values. 
Behavior concerning key/value pairs whose keys are *not* image processor attributes is @@ -413,7 +418,37 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): trust_remote_code = kwargs.pop("trust_remote_code", None) kwargs["_from_auto"] = True - config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs) + # Resolve the image processor config filename + if "image_processor_filename" in kwargs: + image_processor_filename = kwargs.pop("image_processor_filename") + elif is_timm_local_checkpoint(pretrained_model_name_or_path): + image_processor_filename = CONFIG_NAME + else: + image_processor_filename = IMAGE_PROCESSOR_NAME + + # Load the image processor config + try: + # Main path for all transformers models and local TimmWrapper checkpoints + config_dict, _ = ImageProcessingMixin.get_image_processor_dict( + pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs + ) + except Exception as initial_exception: + # Fallback path for Hub TimmWrapper checkpoints. Timm models' image processing is saved in `config.json` + # instead of `preprocessor_config.json`. Because this is an Auto class and we don't have any information + # except the model name, the only way to check if a remote checkpoint is a timm model is to try to + # load `config.json` and if it fails with some error, we raise the initial exception. + try: + config_dict, _ = ImageProcessingMixin.get_image_processor_dict( + pretrained_model_name_or_path, image_processor_filename=CONFIG_NAME, **kwargs + ) + except Exception: + raise initial_exception + + # In case we have a config_dict, but it's not a timm config dict, we raise the initial exception, + # because only timm models have image processing in `config.json`. + if not is_timm_config_dict(config_dict): + raise initial_exception + image_processor_class = config_dict.get("image_processor_type", None) image_processor_auto_map = None if "AutoImageProcessor" in config_dict.get("auto_map", {}): diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 2c519a7dc42c..936437530af3 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -251,6 +251,7 @@ ("time_series_transformer", "TimeSeriesTransformerModel"), ("timesformer", "TimesformerModel"), ("timm_backbone", "TimmBackbone"), + ("timm_wrapper", "TimmWrapperModel"), ("trajectory_transformer", "TrajectoryTransformerModel"), ("transfo-xl", "TransfoXLModel"), ("tvlt", "TvltModel"), @@ -599,6 +600,7 @@ ("table-transformer", "TableTransformerModel"), ("timesformer", "TimesformerModel"), ("timm_backbone", "TimmBackbone"), + ("timm_wrapper", "TimmWrapperModel"), ("van", "VanModel"), ("videomae", "VideoMAEModel"), ("vit", "ViTModel"), @@ -683,6 +685,7 @@ ("swiftformer", "SwiftFormerForImageClassification"), ("swin", "SwinForImageClassification"), ("swinv2", "Swinv2ForImageClassification"), + ("timm_wrapper", "TimmWrapperForImageClassification"), ("van", "VanForImageClassification"), ("vit", "ViTForImageClassification"), ("vit_hybrid", "ViTHybridForImageClassification"), diff --git a/src/transformers/models/timm_wrapper/__init__.py b/src/transformers/models/timm_wrapper/__init__.py new file mode 100644 index 000000000000..9fbc4150412a --- /dev/null +++ b/src/transformers/models/timm_wrapper/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_timm_wrapper import * + from .modeling_timm_wrapper import * + from .processing_timm_wrapper import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py new file mode 100644 index 000000000000..691a2b2b76ec --- /dev/null +++ b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py @@ -0,0 +1,86 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration for TimmWrapper models""" + +from typing import Any, Dict + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class TimmWrapperConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`]. + + It is used to instantiate a timm model according to the specified arguments, defining the model. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + do_pooling (`bool`, *optional*, defaults to `True`): + Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not. 
+ + Example: + ```python + >>> from transformers import TimmWrapperModel + + >>> # Initializing a timm model + >>> model = TimmWrapperModel.from_pretrained("timm/resnet18.a1_in1k") + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "timm_wrapper" + + def __init__(self, initializer_range: float = 0.02, do_pooling: bool = True, **kwargs): + self.initializer_range = initializer_range + self.do_pooling = do_pooling + super().__init__(**kwargs) + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any], **kwargs): + # timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict. + # We are removing these attributes in order to have the native `transformers` num_labels attribute in config + # and to avoid duplicate attributes + + num_labels_in_kwargs = kwargs.pop("num_labels", None) + num_labels_in_dict = config_dict.pop("num_classes", None) + + # passed num_labels has priority over num_classes in config_dict + kwargs["num_labels"] = num_labels_in_kwargs or num_labels_in_dict + + # pop num_classes from "pretrained_cfg", + # it is not necessary to have it, only root one is used in timm + if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]: + config_dict["pretrained_cfg"].pop("num_classes", None) + + return super().from_dict(config_dict, **kwargs) + + def to_dict(self) -> Dict[str, Any]: + output = super().to_dict() + output["num_classes"] = self.num_labels + return output + + +__all__ = ["TimmWrapperConfig"] diff --git a/src/transformers/models/timm_wrapper/image_processing_timm_wrapper.py b/src/transformers/models/timm_wrapper/image_processing_timm_wrapper.py new file mode 100644 index 000000000000..02075a50fb26 --- /dev/null +++ b/src/transformers/models/timm_wrapper/image_processing_timm_wrapper.py @@ -0,0 +1,138 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Any, Dict, Optional, Tuple, Union + +import torch + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import to_pil_image +from ...image_utils import ImageInput, make_list_of_images +from ...utils import TensorType, logging, requires_backends +from ...utils.import_utils import is_timm_available, is_torch_available + + +if is_timm_available(): + import timm + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +class TimmWrapperImageProcessor(BaseImageProcessor): + """ + Wrapper class for timm models to be used within transformers. + + Args: + pretrained_cfg (`Dict[str, Any]`): + The configuration of the pretrained model used to resolve evaluation and + training transforms. + architecture (`Optional[str]`, *optional*): + Name of the architecture of the model. 
+ """ + + main_input_name = "pixel_values" + + def __init__( + self, + pretrained_cfg: Dict[str, Any], + architecture: Optional[str] = None, + **kwargs, + ): + requires_backends(self, "timm") + super().__init__(architecture=architecture) + + self.data_config = timm.data.resolve_data_config(pretrained_cfg, model=None, verbose=False) + self.val_transforms = timm.data.create_transform(**self.data_config, is_training=False) + + # useful for training, see examples/pytorch/image-classification/run_image_classification.py + self.train_transforms = timm.data.create_transform(**self.data_config, is_training=True) + + # If `ToTensor` is in the transforms, then the input should be numpy array or PIL image. + # Otherwise, the input can be a tensor. In later timm versions, `MaybeToTensor` is used + # which can handle both numpy arrays / PIL images and tensors. + self._not_supports_tensor_input = any( + transform.__class__.__name__ == "ToTensor" for transform in self.val_transforms.transforms + ) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + """ + output = super().to_dict() + output.pop("train_transforms", None) + output.pop("val_transforms", None) + output.pop("_not_supports_tensor_input", None) + return output + + @classmethod + def get_image_processor_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + Get the image processor dict for the model. + """ + image_processor_filename = kwargs.pop("image_processor_filename", "config.json") + return super().get_image_processor_dict( + pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs + ) + + def preprocess( + self, + images: ImageInput, + return_tensors: Optional[Union[str, TensorType]] = "pt", + ) -> BatchFeature: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. + """ + if return_tensors != "pt": + raise ValueError(f"return_tensors for TimmWrapperImageProcessor must be 'pt', but got {return_tensors}") + + if self._not_supports_tensor_input and isinstance(images, torch.Tensor): + images = images.cpu().numpy() + + # If the input is a torch tensor, then no conversion is needed + # Otherwise, we need to pass in a list of PIL images + if isinstance(images, torch.Tensor): + images = self.val_transforms(images) + # Add batch dimension if a single image + images = images.unsqueeze(0) if images.ndim == 3 else images + else: + images = make_list_of_images(images) + images = [to_pil_image(image) for image in images] + images = torch.stack([self.val_transforms(image) for image in images]) + + return BatchFeature({"pixel_values": images}, tensor_type=return_tensors) + + def save_pretrained(self, *args, **kwargs): + # disable it to make checkpoint the same as in `timm` library. + logger.warning_once( + "The `save_pretrained` method is disabled for TimmWrapperImageProcessor. " + "The image processor configuration is saved directly in `config.json` when " + "`save_pretrained` is called for saving the model." 
+ ) + + +__all__ = ["TimmWrapperImageProcessor"] diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py new file mode 100644 index 000000000000..dfb14dfccec4 --- /dev/null +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -0,0 +1,363 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor, nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...modeling_outputs import ImageClassifierOutput, ModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings_to_model_forward, + is_timm_available, + replace_return_docstrings, + requires_backends, +) +from .configuration_timm_wrapper import TimmWrapperConfig + + +if is_timm_available(): + import timm + + +@dataclass +class TimmWrapperModelOutput(ModelOutput): + """ + Output class for models TimmWrapperModel, containing the last hidden states, an optional pooled output, + and optional hidden states. + + Args: + last_hidden_state (`torch.FloatTensor`): + The last hidden state of the model, output before applying the classification head. + pooler_output (`torch.FloatTensor`, *optional*): + The pooled output derived from the last hidden state, if applicable. + hidden_states (`tuple(torch.FloatTensor)`, *optional*): + A tuple containing the intermediate hidden states of the model at the output of each layer or specified layers. + Returned if `output_hidden_states=True` is set or if `config.output_hidden_states=True`. + attentions (`tuple(torch.FloatTensor)`, *optional*): + A tuple containing the intermediate attention weights of the model at the output of each layer. + Returned if `output_attentions=True` is set or if `config.output_attentions=True`. + Note: Currently, Timm models do not support attentions output. + """ + + last_hidden_state: torch.FloatTensor + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +TIMM_WRAPPER_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`TimmWrapperImageProcessor.preprocess`] + for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. Not compatible with timm wrapped models. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. Not compatible with timm wrapped models. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ **kwargs: + Additional keyword arguments passed along to the `timm` model forward. +""" + + +class TimmWrapperPreTrainedModel(PreTrainedModel): + main_input_name = "pixel_values" + config_class = TimmWrapperConfig + _no_split_modules = [] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision", "timm"]) + super().__init__(*args, **kwargs) + + @staticmethod + def _fix_state_dict_key_on_load(key): + """ + Overrides original method that renames `gamma` and `beta` to `weight` and `bias`. + We don't want this behavior for timm wrapped models. Instead, this method adds a + "timm_model." prefix to enable loading official timm Hub checkpoints. + """ + if "timm_model." not in key: + return f"timm_model.{key}" + return key + + def _fix_state_dict_key_on_save(self, key): + """ + Overrides original method to remove "timm_model." prefix from state_dict keys. + Makes the saved checkpoint compatible with the `timm` library. + """ + return key.replace("timm_model.", "") + + def load_state_dict(self, state_dict, *args, **kwargs): + """ + Override original method to fix state_dict keys on load for cases when weights are loaded + without using the `from_pretrained` method (e.g., in Trainer to resume from checkpoint). + """ + state_dict = self._fix_state_dict_keys_on_load(state_dict) + return super().load_state_dict(state_dict, *args, **kwargs) + + def _init_weights(self, module): + """ + Initialize weights function to properly initialize Linear layer weights. + Since model architectures may vary, we assume only the classifier requires + initialization, while all other weights should be loaded from the checkpoint. + """ + if isinstance(module, (nn.Linear)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + +class TimmWrapperModel(TimmWrapperPreTrainedModel): + """ + Wrapper class for timm models to be used in transformers. + """ + + def __init__(self, config: TimmWrapperConfig): + super().__init__(config) + # using num_classes=0 to avoid creating classification head + self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=0) + self.post_init() + + @add_start_docstrings_to_model_forward(TIMM_WRAPPER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TimmWrapperModelOutput, config_class=TimmWrapperConfig) + def forward( + self, + pixel_values: torch.FloatTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[Union[bool, List[int]]] = None, + return_dict: Optional[bool] = None, + do_pooling: Optional[bool] = None, + **kwargs, + ) -> Union[TimmWrapperModelOutput, Tuple[Tensor, ...]]: + r""" + do_pooling (`bool`, *optional*): + Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not. If `None` is passed, the + `do_pooling` value from the config is used. + + Returns: + + Examples: + ```python + >>> import torch + >>> from PIL import Image + >>> from urllib.request import urlopen + >>> from transformers import AutoModel, AutoImageProcessor + + >>> # Load image + >>> image = Image.open(urlopen( + ... 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png' + ... 
)) + + >>> # Load model and image processor + >>> checkpoint = "timm/resnet50.a1_in1k" + >>> image_processor = AutoImageProcessor.from_pretrained(checkpoint) + >>> model = AutoModel.from_pretrained(checkpoint).eval() + + >>> # Preprocess image + >>> inputs = image_processor(image) + + >>> # Forward pass + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> # Get pooled output + >>> pooled_output = outputs.pooler_output + + >>> # Get last hidden state + >>> last_hidden_state = outputs.last_hidden_state + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + do_pooling = do_pooling if do_pooling is not None else self.config.do_pooling + + if output_attentions: + raise ValueError("Cannot set `output_attentions` for timm models.") + + if output_hidden_states and not hasattr(self.timm_model, "forward_intermediates"): + raise ValueError( + "The 'output_hidden_states' option cannot be set for this timm model. " + "To enable this feature, the 'forward_intermediates' method must be implemented " + "in the timm model (available in timm versions > 1.*). Please consider using a " + "different architecture or updating the timm package to a compatible version." + ) + + pixel_values = pixel_values.to(self.device, self.dtype) + + if output_hidden_states: + # to enable hidden states selection + if isinstance(output_hidden_states, (list, tuple)): + kwargs["indices"] = output_hidden_states + last_hidden_state, hidden_states = self.timm_model.forward_intermediates(pixel_values, **kwargs) + else: + last_hidden_state = self.timm_model.forward_features(pixel_values, **kwargs) + hidden_states = None + + if do_pooling: + # classification head is not created, applying pooling only + pooler_output = self.timm_model.forward_head(last_hidden_state) + else: + pooler_output = None + + if not return_dict: + outputs = (last_hidden_state, pooler_output, hidden_states) + outputs = tuple(output for output in outputs if output is not None) + return outputs + + return TimmWrapperModelOutput( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + hidden_states=hidden_states, + ) + + +class TimmWrapperForImageClassification(TimmWrapperPreTrainedModel): + """ + Wrapper class for timm models to be used in transformers for image classification. + """ + + def __init__(self, config: TimmWrapperConfig): + super().__init__(config) + + if config.num_labels == 0: + raise ValueError( + "You are trying to load weights into `TimmWrapperForImageClassification` from a checkpoint with no classifier head. " + "Please specify the number of classes, e.g. `model = TimmWrapperForImageClassification.from_pretrained(..., num_labels=10)`, " + "or use `TimmWrapperModel` for feature extraction." 
+ ) + + self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=config.num_labels) + self.num_labels = config.num_labels + self.post_init() + + @add_start_docstrings_to_model_forward(TIMM_WRAPPER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=TimmWrapperConfig) + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[Union[bool, List[int]]] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[ImageClassifierOutput, Tuple[Tensor, ...]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + ```python + >>> import torch + >>> from PIL import Image + >>> from urllib.request import urlopen + >>> from transformers import AutoModelForImageClassification, AutoImageProcessor + + >>> # Load image + >>> image = Image.open(urlopen( + ... 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png' + ... )) + + >>> # Load model and image processor + >>> checkpoint = "timm/resnet50.a1_in1k" + >>> image_processor = AutoImageProcessor.from_pretrained(checkpoint) + >>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval() + + >>> # Preprocess image + >>> inputs = image_processor(image) + + >>> # Forward pass + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + + >>> # Get top 5 predictions + >>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5) + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + if output_attentions: + raise ValueError("Cannot set `output_attentions` for timm models.") + + if output_hidden_states and not hasattr(self.timm_model, "forward_intermediates"): + raise ValueError( + "The 'output_hidden_states' option cannot be set for this timm model. " + "To enable this feature, the 'forward_intermediates' method must be implemented " + "in the timm model (available in timm versions > 1.*). Please consider using a " + "different architecture or updating the timm package to a compatible version." 
+ ) + + pixel_values = pixel_values.to(self.device, self.dtype) + + if output_hidden_states: + # to enable hidden states selection + if isinstance(output_hidden_states, (list, tuple)): + kwargs["indices"] = output_hidden_states + last_hidden_state, hidden_states = self.timm_model.forward_intermediates(pixel_values, **kwargs) + logits = self.timm_model.forward_head(last_hidden_state) + else: + logits = self.timm_model(pixel_values, **kwargs) + hidden_states = None + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + outputs = (loss, logits, hidden_states) + outputs = tuple(output for output in outputs if output is not None) + return outputs + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + ) + + +__all__ = ["TimmWrapperPreTrainedModel", "TimmWrapperModel", "TimmWrapperForImageClassification"] diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index f7e962bec346..5aa4f7e7346d 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -55,6 +55,8 @@ is_tensor, is_tf_symbolic_tensor, is_tf_tensor, + is_timm_config_dict, + is_timm_local_checkpoint, is_torch_device, is_torch_dtype, is_torch_tensor, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 1238f058783c..031f3fd362c0 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -8973,6 +8973,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class TimmWrapperForImageClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class TimmWrapperModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class TimmWrapperPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class TrOCRForCausalLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_timm_and_torchvision_objects.py b/src/transformers/utils/dummy_timm_and_torchvision_objects.py new file mode 100644 index 000000000000..8b67b5dac58d --- /dev/null +++ b/src/transformers/utils/dummy_timm_and_torchvision_objects.py @@ -0,0 +1,9 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. 
+from ..utils import DummyObject, requires_backends + + +class TimmWrapperImageProcessor(metaclass=DummyObject): + _backends = ["timm", "torchvision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["timm", "torchvision"]) diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 26ec82b20fd4..a997da79e841 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -16,6 +16,8 @@ """ import inspect +import json +import os import tempfile import warnings from collections import OrderedDict, UserDict @@ -24,7 +26,7 @@ from dataclasses import fields, is_dataclass from enum import Enum from functools import partial, wraps -from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict +from typing import Any, ContextManager, Dict, Iterable, List, Optional, Tuple, TypedDict import numpy as np from packaging import version @@ -867,3 +869,36 @@ class LossKwargs(TypedDict, total=False): """ num_items_in_batch: Optional[int] + + +def is_timm_config_dict(config_dict: Dict[str, Any]) -> bool: + """Checks whether a config dict is a timm config dict.""" + return "pretrained_cfg" in config_dict + + +def is_timm_local_checkpoint(pretrained_model_path: str) -> bool: + """ + Checks whether a checkpoint is a timm model checkpoint. + """ + if pretrained_model_path is None: + return False + + # in case it's Path, not str + pretrained_model_path = str(pretrained_model_path) + + is_file = os.path.isfile(pretrained_model_path) + is_dir = os.path.isdir(pretrained_model_path) + + # pretrained_model_path is a file + if is_file and pretrained_model_path.endswith(".json"): + with open(pretrained_model_path, "r") as f: + config_dict = json.load(f) + return is_timm_config_dict(config_dict) + + # pretrained_model_path is a directory with a config.json + if is_dir and os.path.exists(os.path.join(pretrained_model_path, "config.json")): + with open(os.path.join(pretrained_model_path, "config.json"), "r") as f: + config_dict = json.load(f) + return is_timm_config_dict(config_dict) + + return False diff --git a/tests/models/timm_wrapper/__init__.py b/tests/models/timm_wrapper/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/timm_wrapper/test_image_processing_timm_wrapper.py b/tests/models/timm_wrapper/test_image_processing_timm_wrapper.py new file mode 100644 index 000000000000..49d864178d14 --- /dev/null +++ b/tests/models/timm_wrapper/test_image_processing_timm_wrapper.py @@ -0,0 +1,103 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tempfile +import unittest + +import numpy as np + +from transformers.testing_utils import require_torch, require_torchvision, require_vision +from transformers.utils import is_torch_available, is_vision_available + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + from transformers import TimmWrapperConfig, TimmWrapperImageProcessor + + +@require_torch +@require_vision +@require_torchvision +class TimmWrapperImageProcessingTest(unittest.TestCase): + image_processing_class = TimmWrapperImageProcessor if is_vision_available() else None + + def setUp(self): + super().setUp() + self.temp_dir = tempfile.TemporaryDirectory() + config = TimmWrapperConfig.from_pretrained("timm/resnet18.a1_in1k") + config.save_pretrained(self.temp_dir.name) + + def tearDown(self): + self.temp_dir.cleanup() + + def test_load_from_hub(self): + image_processor = TimmWrapperImageProcessor.from_pretrained("timm/resnet18.a1_in1k") + self.assertIsInstance(image_processor, TimmWrapperImageProcessor) + + def test_load_from_local_dir(self): + image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name) + self.assertIsInstance(image_processor, TimmWrapperImageProcessor) + + def test_image_processor_properties(self): + image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name) + self.assertTrue(hasattr(image_processor, "data_config")) + self.assertTrue(hasattr(image_processor, "val_transforms")) + self.assertTrue(hasattr(image_processor, "train_transforms")) + + def test_image_processor_call_numpy(self): + image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name) + + single_image = np.random.randint(256, size=(256, 256, 3), dtype=np.uint8) + batch_images = [single_image, single_image, single_image] + + # single image + pixel_values = image_processor(single_image).pixel_values + self.assertEqual(pixel_values.shape, (1, 3, 224, 224)) + + # batch images + pixel_values = image_processor(batch_images).pixel_values + self.assertEqual(pixel_values.shape, (3, 3, 224, 224)) + + def test_image_processor_call_pil(self): + image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name) + + single_image = Image.fromarray(np.random.randint(256, size=(256, 256, 3), dtype=np.uint8)) + batch_images = [single_image, single_image, single_image] + + # single image + pixel_values = image_processor(single_image).pixel_values + self.assertEqual(pixel_values.shape, (1, 3, 224, 224)) + + # batch images + pixel_values = image_processor(batch_images).pixel_values + self.assertEqual(pixel_values.shape, (3, 3, 224, 224)) + + def test_image_processor_call_tensor(self): + image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name) + + single_image = torch.from_numpy(np.random.randint(256, size=(3, 256, 256), dtype=np.uint8)).float() + batch_images = [single_image, single_image, single_image] + + # single image + pixel_values = image_processor(single_image).pixel_values + self.assertEqual(pixel_values.shape, (1, 3, 224, 224)) + + # batch images + pixel_values = image_processor(batch_images).pixel_values + self.assertEqual(pixel_values.shape, (3, 3, 224, 224)) diff --git a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py new file mode 100644 index 000000000000..6f63c0aa147d --- /dev/null +++ b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py @@ -0,0 +1,366 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import tempfile +import unittest + +from transformers.testing_utils import ( + require_bitsandbytes, + require_timm, + require_torch, + require_vision, + slow, + torch_device, +) +from transformers.utils.import_utils import is_timm_available, is_torch_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import TimmWrapperConfig, TimmWrapperForImageClassification, TimmWrapperModel + + +if is_timm_available(): + import timm + + +if is_vision_available(): + from PIL import Image + + from transformers import TimmWrapperImageProcessor + + +class TimmWrapperModelTester: + def __init__( + self, + parent, + model_name="timm/resnet18.a1_in1k", + batch_size=3, + image_size=32, + num_channels=3, + is_training=True, + ): + self.parent = parent + self.model_name = model_name + self.batch_size = batch_size + self.image_size = image_size + self.num_channels = num_channels + self.is_training = is_training + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + config = self.get_config() + + return config, pixel_values + + def get_config(self): + return TimmWrapperConfig.from_pretrained(self.model_name) + + def create_and_check_model(self, config, pixel_values): + model = TimmWrapperModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(pixel_values) + self.parent.assertEqual( + result.feature_map[-1].shape, + (self.batch_size, model.channels[-1], 14, 14), + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +@require_timm +class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (TimmWrapperModel, TimmWrapperForImageClassification) if is_torch_available() else () + pipeline_model_mapping = ( + {"image-feature-extraction": TimmWrapperModel, "image-classification": TimmWrapperForImageClassification} + if is_torch_available() + else {} + ) + + test_resize_embeddings = False + test_head_masking = False + test_pruning = False + has_attentions = False + test_model_parallel = False + + def setUp(self): + self.config_class = TimmWrapperConfig + self.model_tester = TimmWrapperModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=self.config_class, + has_text_modality=False, + common_properties=[], + model_name="timm/resnet18.a1_in1k", + ) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_hidden_states_output(self): + config, inputs_dict = 
self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + + # check all hidden states + with torch.no_grad(): + outputs = model(**inputs_dict, output_hidden_states=True) + self.assertTrue( + len(outputs.hidden_states) == 5, f"expected 5 hidden states, but got {len(outputs.hidden_states)}" + ) + expected_shapes = [[16, 16], [8, 8], [4, 4], [2, 2], [1, 1]] + resulted_shapes = [list(h.shape[2:]) for h in outputs.hidden_states] + self.assertListEqual(expected_shapes, resulted_shapes) + + # check we can select hidden states by indices + with torch.no_grad(): + outputs = model(**inputs_dict, output_hidden_states=[-2, -1]) + self.assertTrue( + len(outputs.hidden_states) == 2, f"expected 2 hidden states, but got {len(outputs.hidden_states)}" + ) + expected_shapes = [[2, 2], [1, 1]] + resulted_shapes = [list(h.shape[2:]) for h in outputs.hidden_states] + self.assertListEqual(expected_shapes, resulted_shapes) + + @unittest.skip(reason="TimmWrapper models doesn't have inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="TimmWrapper models doesn't have inputs_embeds") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="TimmWrapper doesn't support output_attentions=True.") + def test_torchscript_output_attentions(self): + pass + + @unittest.skip(reason="TimmWrapper doesn't support this.") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="TimmWrapper initialization is managed on the timm side") + def test_initialization(self): + pass + + @unittest.skip(reason="Need to use a timm model and there is no tiny model available.") + def test_model_is_small(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_do_pooling_option(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.do_pooling = False + + model = TimmWrapperModel._from_config(config) + + # check there is no pooling + with torch.no_grad(): + output = model(**inputs_dict) + self.assertIsNone(output.pooler_output) + + # check there is pooler output + with torch.no_grad(): + output = model(**inputs_dict, do_pooling=True) + self.assertIsNotNone(output.pooler_output) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_torch +@require_timm +@require_vision +class TimmWrapperModelIntegrationTest(unittest.TestCase): + # some popular ones + model_names_to_test = [ + "vit_small_patch16_384.augreg_in21k_ft_in1k", + "resnet50.a1_in1k", + "tf_mobilenetv3_large_minimal_100.in1k", + "swin_tiny_patch4_window7_224.ms_in1k", + "ese_vovnet19b_dw.ra_in1k", + "hrnet_w18.ms_aug_in1k", + ] + + @slow + def test_inference_image_classification_head(self): + checkpoint = "timm/resnet18.a1_in1k" + model = TimmWrapperForImageClassification.from_pretrained(checkpoint, device_map=torch_device).eval() + image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint) + + image = prepare_img() + inputs = 
image_processor(images=image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs) + + # verify the shape and logits + expected_shape = torch.Size((1, 1000)) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_label = 281 # tabby cat + self.assertEqual(torch.argmax(outputs.logits).item(), expected_label) + + expected_slice = torch.tensor([-11.2618, -9.6192, -10.3205]).to(torch_device) + resulted_slice = outputs.logits[0, :3] + is_close = torch.allclose(resulted_slice, expected_slice, atol=1e-3) + self.assertTrue(is_close, f"Expected {expected_slice}, but got {resulted_slice}") + + @slow + @require_bitsandbytes + def test_inference_image_classification_quantized(self): + from transformers import BitsAndBytesConfig + + checkpoint = "timm/vit_small_patch16_384.augreg_in21k_ft_in1k" + + quantization_config = BitsAndBytesConfig(load_in_8bit=True) + model = TimmWrapperForImageClassification.from_pretrained( + checkpoint, quantization_config=quantization_config, device_map=torch_device + ).eval() + image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint) + + image = prepare_img() + inputs = image_processor(images=image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs) + + # verify the shape and logits + expected_shape = torch.Size((1, 1000)) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_label = 281 # tabby cat + self.assertEqual(torch.argmax(outputs.logits).item(), expected_label) + + expected_slice = torch.tensor([-2.4043, 1.4492, -0.5127]).to(outputs.logits.dtype) + resulted_slice = outputs.logits[0, :3].cpu() + is_close = torch.allclose(resulted_slice, expected_slice, atol=0.1) + self.assertTrue(is_close, f"Expected {expected_slice}, but got {resulted_slice}") + + @slow + def test_transformers_model_for_classification_is_equivalent_to_timm(self): + # check that wrapper logits are the same as timm model logits + + image = prepare_img() + + for model_name in self.model_names_to_test: + checkpoint = f"timm/{model_name}" + + with self.subTest(msg=model_name): + # prepare inputs + image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint) + pixel_values = image_processor(images=image).pixel_values.to(torch_device) + + # load models + model = TimmWrapperForImageClassification.from_pretrained(checkpoint, device_map=torch_device).eval() + timm_model = timm.create_model(model_name, pretrained=True).to(torch_device).eval() + + with torch.inference_mode(): + outputs = model(pixel_values) + timm_outputs = timm_model(pixel_values) + + # check shape is the same + self.assertEqual(outputs.logits.shape, timm_outputs.shape) + + # check logits are the same + diff = (outputs.logits - timm_outputs).max().item() + self.assertLess(diff, 1e-4) + + @slow + def test_transformers_model_is_equivalent_to_timm(self): + # check that wrapper logits are the same as timm model logits + + image = prepare_img() + + models_to_test = ["vit_small_patch16_224.dino"] + self.model_names_to_test + + for model_name in models_to_test: + checkpoint = f"timm/{model_name}" + + with self.subTest(msg=model_name): + # prepare inputs + image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint) + pixel_values = image_processor(images=image).pixel_values.to(torch_device) + + # load models + model = TimmWrapperModel.from_pretrained(checkpoint, device_map=torch_device).eval() + timm_model = timm.create_model(model_name, pretrained=True, 
num_classes=0).to(torch_device).eval() + + with torch.inference_mode(): + outputs = model(pixel_values) + timm_outputs = timm_model(pixel_values) + + # check shape is the same + self.assertEqual(outputs.pooler_output.shape, timm_outputs.shape) + + # check logits are the same + diff = (outputs.pooler_output - timm_outputs).max().item() + self.assertLess(diff, 1e-4) + + @slow + def test_save_load_to_timm(self): + # test that timm model can be loaded to transformers, saved and then loaded back into timm + + model = TimmWrapperForImageClassification.from_pretrained( + "timm/resnet18.a1_in1k", num_labels=10, ignore_mismatched_sizes=True + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + # there is no direct way to load timm model from folder, use the same config + path to weights + timm_model = timm.create_model( + "resnet18", num_classes=10, checkpoint_path=f"{tmpdirname}/model.safetensors" + ) + + # check that all weights are the same after reload + different_weights = [] + for (name1, param1), (name2, param2) in zip( + model.timm_model.named_parameters(), timm_model.named_parameters() + ): + if param1.shape != param2.shape or not torch.equal(param1, param2): + different_weights.append((name1, name2)) + + if different_weights: + self.fail(f"Found different weights after reloading: {different_weights}") diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 99d0a8058c67..13eacc4a5965 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3443,6 +3443,7 @@ def test_mismatched_shapes_have_properly_initialized_weights(self): "Data2VecAudioForSequenceClassification", "UniSpeechForSequenceClassification", "PvtForImageClassification", + "TimmWrapperForImageClassification", ] special_param_names = [ r"^bit\.", @@ -3463,6 +3464,7 @@ def test_mismatched_shapes_have_properly_initialized_weights(self): r"^swiftformer\.", r"^swinv2\.", r"^transformers\.models\.swiftformer\.", + r"^timm_model\.", r"^unispeech\.", r"^unispeech_sat\.", r"^vision_model\.", diff --git a/utils/check_config_docstrings.py b/utils/check_config_docstrings.py index d243dd0c35b6..341fc42b9c68 100644 --- a/utils/check_config_docstrings.py +++ b/utils/check_config_docstrings.py @@ -41,6 +41,7 @@ "RagConfig", "SpeechEncoderDecoderConfig", "TimmBackboneConfig", + "TimmWrapperConfig", "VisionEncoderDecoderConfig", "VisionTextDualEncoderConfig", "LlamaConfig",
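
A detail of the wrapper that is easy to miss when reading the diff: `output_hidden_states` accepts not only a boolean but also a list or tuple of stage indices, which the forward pass hands to timm's `forward_intermediates` as `indices`. The sketch below is illustrative only, not part of the patch; it reuses the `timm/resnet18.a1_in1k` checkpoint from the tests and a dummy input tensor in place of real preprocessed images.

```python
import torch

from transformers import TimmWrapperForImageClassification

# Checkpoint taken from the tests above; any timm classification checkpoint
# wrapped by TimmWrapperForImageClassification is expected to behave the same way.
checkpoint = "timm/resnet18.a1_in1k"
model = TimmWrapperForImageClassification.from_pretrained(checkpoint).eval()

# Dummy batch standing in for real preprocessed images.
pixel_values = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    # output_hidden_states=True returns every intermediate feature map ...
    all_stages = model(pixel_values, output_hidden_states=True)
    # ... while a list/tuple of indices selects specific stages only
    # (here the last two), mirroring test_hidden_states_output above.
    last_two = model(pixel_values, output_hidden_states=[-2, -1])

print(len(all_stages.hidden_states))  # 5 stages for resnet18
print(len(last_two.hidden_states))    # 2
```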
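
The `generic.py` hunk also adds two small helpers, `is_timm_config_dict` and `is_timm_local_checkpoint`, for detecting timm-style configs and local checkpoints; their call sites are outside this excerpt. A minimal, self-contained sketch of their behaviour follows, with the config contents beyond the `pretrained_cfg` key chosen purely for illustration.

```python
import json
import os
import tempfile

from transformers.utils import is_timm_config_dict, is_timm_local_checkpoint

# A config dict counts as a timm config when it carries a "pretrained_cfg" entry.
print(is_timm_config_dict({"pretrained_cfg": {"input_size": [3, 224, 224]}}))  # True
print(is_timm_config_dict({"model_type": "resnet"}))                           # False

# The local-checkpoint check accepts either a config.json file or a directory
# containing one, and returns False for None or anything else.
with tempfile.TemporaryDirectory() as tmp_dir:
    config_path = os.path.join(tmp_dir, "config.json")
    with open(config_path, "w") as f:
        json.dump({"pretrained_cfg": {"input_size": [3, 224, 224]}}, f)

    print(is_timm_local_checkpoint(config_path))  # True: a timm-style config.json file
    print(is_timm_local_checkpoint(tmp_dir))      # True: a directory containing one
    print(is_timm_local_checkpoint(None))         # False
```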