Conversation

@qubvel (Contributor) commented Nov 1, 2024

What does this PR do?

Adds a TimmWrapper set of classes so that timm models can be loaded as transformers models in the library.

Continuation of

General Usage

import torch
from urllib.request import urlopen
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor

checkpoint = "timm/resnet50.a1_in1k"
img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

image_processor = AutoImageProcessor.from_pretrained(checkpoint)
inputs = image_processor(img, return_tensors="pt")
model = AutoModelForImageClassification.from_pretrained(checkpoint)

with torch.no_grad():
    logits = model(**inputs).logits

top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5)
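To make the prediction readable, the top-5 indices can be mapped back to label names. A minimal sketch, assuming the loaded config exposes the standard id2label mapping:

for prob, idx in zip(top5_probabilities[0], top5_class_indices[0]):
    print(f"{model.config.id2label[idx.item()]}: {prob.item():.2f}%")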

Pipeline

Timm models can now be used in the image-classification pipeline (for classification models) and the image-feature-extraction pipeline.

import torch
from urllib.request import urlopen
from PIL import Image

from transformers import pipeline

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))
pipe = pipeline("image-classification", model="timm/resnet18.a1_in1k")
print(pipe(img))
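A minimal sketch of the feature-extraction side mentioned above, assuming the same checkpoint can back the image-feature-extraction pipeline:

feature_extractor = pipeline("image-feature-extraction", model="timm/resnet18.a1_in1k")
features = feature_extractor(img)
print(len(features[0]))  # nesting/shape depends on the model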

Trainer

Timm models can now be loaded and trained with the Trainer class.

Example model trained with the Trainer by running the command below:
https://huggingface.co/qubvel-hf/vit-base-beans

python run_image_classification.py \
    --dataset_name beans \
    --output_dir ./beans_outputs/ \
    --remove_unused_columns False \
    --label_column_name labels \
    --do_train \
    --do_eval \
    --push_to_hub \
    --push_to_hub_model_id vit-base-beans \
    --learning_rate 2e-5 \
    --num_train_epochs 5 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --logging_strategy steps \
    --logging_steps 10 \
    --eval_strategy epoch \
    --save_strategy epoch \
    --load_best_model_at_end True \
    --save_total_limit 3 \
    --seed 1337 \
    --model_name_or_path timm/resnet18.a1_in1k \
    --ignore_mismatched_sizes

Other features enabled

  • Device map:
model = TimmWrapperForImageClassification.from_pretrained(checkpoint, device_map="auto")
  • Torch dtype:
model = TimmWrapperForImageClassification.from_pretrained(checkpoint, torch_dtype="bfloat16")
  • Quantization:
model = TimmWrapperForImageClassification.from_pretrained(checkpoint, load_in_4bit=True)
  • Intermediate hidden states: output_hidden_states=True or output_hidden_states=[1, 2, 3] (to select specific hidden states; see the sketch after this list)
model = TimmWrapperForImageClassification.from_pretrained(checkpoint)
output = model(**inputs, output_hidden_states=True)
  • Transformers TimmWrapper checkpoints are compatible with timm:
model = timm.create_model("hf-hub:qubvel-hf/vit-base-beans", pretrained=True)
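A minimal sketch of selecting specific hidden states (assuming `checkpoint` and `inputs` from the General Usage example above, and that TimmWrapperForImageClassification is importable from transformers):

import torch
from transformers import TimmWrapperForImageClassification

model = TimmWrapperForImageClassification.from_pretrained(checkpoint)
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=[1, 2, 3])
# one feature map per requested stage (assumed indexing semantics)
print(len(outputs.hidden_states))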

TODO

  • Gamma/beta renaming issue
  • Update timm in CI 0.9.6 -> 1.0.11 to enable output_hidden_states tests
    • CI for slow-run takes longer to update images
  • Weights are loaded by transformers instead of timm; which architectures are affected?
  • Tests for image processor

@qubvel qubvel marked this pull request as draft November 1, 2024 15:52
@qubvel qubvel marked this pull request as ready for review December 2, 2024 10:58
@LysandreJik (Member) left a comment

This is starting to look nice!

>>> # Load model and image processor
>>> checkpoint = "timm/resnet50.a1_in1k"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
>>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()
Member:

Really nice that there are no kwargs or whatever needed to load the model

# Check format of the archive
with safe_open(checkpoint_file, framework="pt") as f:
    metadata = f.metadata()
if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
Member:

Ok with this

Comment on lines -636 to -684
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
new_keys = []
renamed_keys = {}
renamed_gamma = {}
renamed_beta = {}
warning_msg = f"A pretrained model of type `{model_to_load.__class__.__name__}` "
for key in state_dict.keys():
    new_key = None
    if "gamma" in key:
        # We add only the first key as an example
        new_key = key.replace("gamma", "weight")
        renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
    if "beta" in key:
        # We add only the first key as an example
        new_key = key.replace("beta", "bias")
        renamed_beta[key] = new_key if not renamed_beta else renamed_beta
    if new_key:
        old_keys.append(key)
        new_keys.append(new_key)
renamed_keys = {**renamed_gamma, **renamed_beta}
if renamed_keys:
    warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
    for old_key, new_key in renamed_keys.items():
        warning_msg += f"* `{old_key}` -> `{new_key}`\n"
    warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
    logger.info_once(warning_msg)
for old_key, new_key in zip(old_keys, new_keys):
    state_dict[new_key] = state_dict.pop(old_key)

Member:

This should eventually be completely removed, cc @ArthurZucker

Comment on lines 3910 to 3911
if metadata is None:
    pass
Member:

Can you please add a comment saying that in case of no metadata, it's seen as a pytorch checkpoint

Contributor Author:

Added in 7e0d2c6.

Comment on lines 417 to 420
default_image_processor_filename = (
    "config.json" if is_timm_checkpoint(pretrained_model_name_or_path) else IMAGE_PROCESSOR_NAME
)
kwargs["image_processor_filename"] = kwargs.get("image_processor_filename", default_image_processor_filename)
Member:

You can use CONFIG_NAME instead

num_items_in_batch: Optional[int]


def is_timm_hub_checkpoint(pretrained_model_name_or_path: str) -> bool:
Member:

Given the method's objective, I would have it accept only a pretrained_model_name

Contributor Author:

function removed with ff6efde (see comment below)

Comment on lines 879 to 881
if os.path.isfile(pretrained_model_name_or_path) or os.path.isdir(pretrained_model_name_or_path):
    return False

Member:

and I'd therefore remove this

@qubvel (Contributor Author) commented Dec 2, 2024:

function removed with ff6efde (see comment below)

Comment on lines 882 to 884
return pretrained_model_name_or_path.startswith("hf-hub:timm/") or pretrained_model_name_or_path.startswith(
    "timm/"
)
Member:

Good as long as we don't expect community checkpoints, but there are community checkpoints already, for example see the following: https://huggingface.co/prov-gigapath/prov-gigapath

I think we'll need a more robust check here

@qubvel (Contributor Author) commented Dec 2, 2024:

Indeed, this is a weak assumption. This function is only needed in the image processors auto class to load it from the config. In AutoImageProcessor.from_pretrained the only information we have is the model name, so the only way to make a robust check is to load files from the Hub.

I removed this function entirely and instead made a fallback for loading the timm image processor dict in the auto class of the image processor. To avoid loading the config for every model, I did it in the following way:

  1. Try to load the image processor config as usual - most of the models will be fine, and we won't have any overhead here.
  2. In case of an exception, try loading config.json and check if it's a timm checkpoint.

See ff6efde for details. I documented it in the code, let me know if you have doubts about this approach.

Works with

image_processor = AutoImageProcessor.from_pretrained("prov-gigapath/prov-gigapath")

Member:

This looks better to me!

def test_model_is_small(self):
    pass

# Overriding as output_attentions is not supported by TimmWrapper
Member:

This should be removed, no?

Contributor Author:

Thanks, removed in a476610


@require_torch
@require_vision
class TimmWrapperModelIntegrationTest(unittest.TestCase):
Member:

Should require timm as well

Contributor Author:

Added in 327095a

@LysandreJik (Member):

Thanks, this looks good! cc @molbap can you give the processor code a quick look just to double check?

@molbap (Contributor) left a comment

Hey, took a quick look at the processor and ran it, found some stuff which I commented on! Also looked at the whole PR, real nice work!

local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", "")
image_processor_filename = kwargs.pop("image_processor_filename", IMAGE_PROCESSOR_NAME)
Contributor:

I understand. Since it's a new kwarg that does not seem to have an equivalent in Hub methods (like use_auth_token or revision), I'd add a small docstring to advertise it.

Comment on lines +428 to +444
try:
    # Main path for all transformers models and local TimmWrapper checkpoints
    config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
        pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
    )
except Exception as initial_exception:
    # Fallback path for Hub TimmWrapper checkpoints. Timm models' image processing is saved in `config.json`
    # instead of `preprocessor_config.json`. Because this is an Auto class and we don't have any information
    # except the model name, the only way to check if a remote checkpoint is a timm model is to try to
    # load `config.json` and if it fails with some error, we raise the initial exception.
    try:
        config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
            pretrained_model_name_or_path, image_processor_filename=CONFIG_NAME, **kwargs
        )
    except Exception:
        raise initial_exception

Contributor:

A remark here: that's true for any processor with a preprocessor_config.json, but would it make sense to sanitize the inputs a bit? On community checkpoints, there are extra keys that are unused, for instance model_args, which contains duplicated information.
(I'm asking because we're already using a try/catch pattern here so that allows some branching)

Contributor Author:

Not sure I got it, can you provide more details, please?

Contributor:

Basically, if I load https://huggingface.co/prov-gigapath/prov-gigapath/blob/main/config.json, I'll end up with an ImageProcessor object that has keys I can't make much use of in transformers, such as model_args, so I wondered if it made sense to filter the contents of that config.json to an expected schema.

Contributor Author:

Ok, it makes sense now. I made it on the image-processor init level, see fd7b646

Is it what you had in mind?

>>> from transformers import AutoImageProcessor, AutoConfig

>>> image_processor = AutoImageProcessor.from_pretrained("prov-gigapath/prov-gigapath")
>>> print(image_processor)

TimmWrapperImageProcessor {
  "architecture": "vit_giant_patch14_dinov2",
  "data_config": {
    "crop_mode": "center",
    "crop_pct": 1.0,
    "input_size": [
      3,
      224,
      224
    ],
    "interpolation": "bicubic",
    "mean": [
      0.485,
      0.456,
      0.406
    ],
    "std": [
      0.229,
      0.224,
      0.225
    ]
  },
  "image_processor_type": "TimmWrapperImageProcessor"
}
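For illustration, the sanitization amounts to roughly the following sketch (the key set is assumed from the printout above; the actual change is in fd7b646):

# hypothetical: keep only keys the image processor understands, dropping extras like `model_args`
expected_keys = {"architecture", "data_config", "image_processor_type"}
config_dict = {k: v for k, v in config_dict.items() if k in expected_keys}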

Contributor:

yes, exactly! that looks good

Comment on lines +96 to +99
if isinstance(images, torch.Tensor):
    images = self.val_transforms(images)
    # Add batch dimension if a single image
    images = images.unsqueeze(0) if images.ndim == 3 else images
Contributor:

Here, if val_transforms is, for instance:

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)

Then the ToTensor op will fail, since F.to_tensor does expect a PIL image IIRC
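A quick repro of that failure mode with plain torchvision transforms (versions may word the error slightly differently):

import torch
from torchvision.transforms import ToTensor

ToTensor()(torch.rand(3, 224, 224))
# TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>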

Contributor Author:

Nice catch! Since timm>=1.0.8 it's MaybeToTensor and works fine, but on timm<1.0.8 it indeed raises an error:

TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>

I can add something like (assuming packaging's version helper and timm are imported):

if version.parse(timm.__version__) < version.parse("1.0.8") and isinstance(images, torch.Tensor):
    images = images.cpu().numpy()

Contributor Author:

Added in a10bc0d

Contributor:

perfect!

Contributor:

@qubvel version checks tend to be brittle, can you do a hasattr on the MaybeToTensor class existing? I think that should line up with its use in the transforms?

Contributor Author:

@rwightman, thanks for the note. I added a check for the class name; it doesn't look that elegant, but I hope it's more robust: 3d1a76e
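A minimal sketch of that kind of check (hypothetical helper; the actual code is in 3d1a76e):

def _contains_maybe_to_tensor(compose) -> bool:
    # detect by class name rather than by timm version,
    # since MaybeToTensor only exists in newer timm releases
    return any(t.__class__.__name__ == "MaybeToTensor" for t in compose.transforms)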

@qubvel (Contributor Author) commented Dec 4, 2024

@LysandreJik @molbap @rwightman Thanks for the reviews! I believe all comments have been addressed. Do you have anything else in mind? It would be nice to move it forward.

@LysandreJik (Member):

Ok awesome! At this point just a second quick look from @ArthurZucker and we're good

@ArthurZucker (Collaborator) left a comment

Super nice! This is kind of a perfect integration with the Auto API, congrats!

Comment on lines 341 to 345
normalize = (
    Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
    if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
    else Lambda(lambda x: x)
)
Collaborator:

let's do something explicit for this!

Contributor Author:

I'm not sure what you meant here, but I made it a bit more readable in 4edfe90 IMO. It's actually not related to the timm wrapper, it's the same in the original code.

Comment on lines -636 to -684
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
new_keys = []
renamed_keys = {}
renamed_gamma = {}
renamed_beta = {}
warning_msg = f"A pretrained model of type `{model_to_load.__class__.__name__}` "
for key in state_dict.keys():
    new_key = None
    if "gamma" in key:
        # We add only the first key as an example
        new_key = key.replace("gamma", "weight")
        renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
    if "beta" in key:
        # We add only the first key as an example
        new_key = key.replace("beta", "bias")
        renamed_beta[key] = new_key if not renamed_beta else renamed_beta
    if new_key:
        old_keys.append(key)
        new_keys.append(new_key)
renamed_keys = {**renamed_gamma, **renamed_beta}
if renamed_keys:
    warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
    for old_key, new_key in renamed_keys.items():
        warning_msg += f"* `{old_key}` -> `{new_key}`\n"
    warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
    logger.info_once(warning_msg)
for old_key, new_key in zip(old_keys, new_keys):
    state_dict[new_key] = state_dict.pop(old_key)

Collaborator:

Yep, there is #33192, kinda related.

@qubvel qubvel merged commit 5fcf628 into huggingface:main Dec 11, 2024
27 checks passed
Cemberk pushed a commit to ROCm/transformers that referenced this pull request Mar 6, 2025
* Add files

* Init

* Add TimmWrapperModel

* Fix up

* Some fixes

* Fix up

* Remove old file

* Sort out import orders

* Fix some model loading

* Compatible with pipeline and trainer

* Fix up

* Delete test_timm_model_1/config.json

* Remove accidentally committed files

* Delete src/transformers/models/modeling_timm_wrapper.py

* Remove empty imports; fix transformations applied

* Tidy up

* Add image classification model to special cases

* Create pretrained model; enable device_map='auto'

* Enable most tests; fix init order

* Sort imports

* [run-slow] timm_wrapper

* Pass num_classes into timm.create_model

* Remove train transforms from image processor

* Update timm creation with pretrained=False

* Fix gamma/beta issue for timm models

* Fixing gamma and beta renaming for timm models

* Simplify config and model creation

* Remove attn_implementation diff

* Fixup

* Docstrings

* Fix warning msg text according to test case

* Fix device_map auto

* Set dtype and device for pixel_values in forward

* Enable output hidden states

* Enable tests for hidden_states and model parallel

* Remove default scriptable arg

* Refactor inner model

* Update timm version

* Fix _find_mismatched_keys function

* Change inheritance for Classification model (fix weights loading with device_map)

* Minor bugfix

* Disable save pretrained for image processor

* Rename hook method for loaded keys correction

* Rename state dict keys on save, remove `timm_model` prefix, make checkpoint compatible with `timm`

* Managing num_labels <-> num_classes attributes

* Enable loading checkpoints in Trainer to resume training

* Update error message for output_hidden_states

* Add output hidden states test

* Decouple base and classification models

* Add more test cases

* Add save-load-to-timm test

* Fix test name

* Fixup

* Add do_pooling

* Add test for do_pooling

* Fix doc

* Add tests for TimmWrapperModel

* Add validation for `num_classes=0` in timm config + test for DINO checkpoint

* Adjust atol for test

* Fix docs

* dev-ci

* dev-ci

* Add tests for image processor

* Update docs

* Update init to new format

* Update docs in configuration

* Fix some docs in image processor

* Improve docs for modeling

* fix for is_timm_checkpoint

* Update code examples

* Fix header

* Fix typehint

* Increase tolerance a bit

* Fix Path

* Fixing model parallel tests

* Disable "parallel" tests

* Add comment for metadata

* Refactor AutoImageProcessor for timm wrapper loading

* Remove custom test_model_outputs_equivalence

* Add require_timm decorator

* Fix comment

* Make image processor work with older timm versions and tensor input

* Save config instead of whole model in image processor tests

* Add docstring for `image_processor_filename`

* Sanitize kwargs for timm image processor

* Fix doc style

* Update check for tensor input

* Update normalize

* Remove _load_timm_model function

---------

Co-authored-by: Amy Roberts <[email protected]>
subfolder (`str`, *optional*, defaults to `""`):
    In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
    specify the folder name here.
image_processor_filename (`str`, *optional*, defaults to `"config.json"`):


Doesn't it default to "preprocessor_config.json"?

@qubvel (Contributor Author) commented Apr 29, 2025:

@hongwhatamazon, yes, you are right! Feel free to propose a PR to fix if you have bandwidth
