Conversation


@Factral Factral commented Jan 18, 2025

What does this PR do?

A few days ago, PR #34564, which adds timm_wrapper (see the accompanying blog post), was merged, enabling timm models to be used directly through Hugging Face interfaces, especially the Auto* classes. However, the AutoFeatureExtractor interface currently doesn't work with these models. This PR addresses that gap.

This PR adds timm_wrapper compatibility to AutoFeatureExtractor.from_pretrained(), enabling it to work with fine-tuned/trained timm model checkpoints.

Currently, when using a checkpoint from a trained/fine-tuned timm model (e.g., one produced by examples/pytorch/image-classification/run_image_classification.py), AutoFeatureExtractor.from_pretrained() fails because timm_wrapper is not included in the interface.

Although a warning is emitted about the missing preprocessor_config.json in such checkpoints, users can add the file manually, following examples like https://huggingface.co/Factral/vit_large-model/blob/main/preprocessor_config.json. This PR ensures AutoFeatureExtractor works properly once that file is present.

Changes

  • Added timm_wrapper to AutoFeatureExtractor interface
  • Enables compatibility with timm model checkpoints when preprocessor_config.json is present
  • Added is_timm kwarg in from_dict function
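To illustrate the shape of the is_timm branch described above, here is a minimal, hypothetical sketch. The function name, defaults, and key mappings are illustrative only, not the actual transformers implementation; it shows how timm-style config keys (input_size, mean, std) could be mapped onto the keys a feature extractor expects:

```python
# Hypothetical sketch (NOT the actual transformers code): a from_dict-style
# helper that accepts an is_timm flag and remaps timm-style config keys.

def feature_extractor_from_dict(config_dict, is_timm=False):
    """Build a kwargs dict for a feature extractor.

    When is_timm is True, the preprocessing values are read from
    timm-style keys ("input_size", "mean", "std") rather than the
    usual transformers keys.
    """
    if is_timm:
        # timm stores input_size as (channels, height, width)
        _, height, width = config_dict.get("input_size", (3, 224, 224))
        return {
            "size": {"height": height, "width": width},
            "image_mean": list(config_dict.get("mean", (0.485, 0.456, 0.406))),
            "image_std": list(config_dict.get("std", (0.229, 0.224, 0.225))),
        }
    # Non-timm configs are assumed to already use the expected keys.
    return dict(config_dict)


timm_cfg = {"input_size": (3, 224, 224), "mean": (0.5, 0.5, 0.5), "std": (0.5, 0.5, 0.5)}
print(feature_extractor_from_dict(timm_cfg, is_timm=True))
```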

Before submitting

  • Read contributor guidelines
  • Updated documentation to reflect changes
  • Added necessary tests for timm_wrapper functionality

Who can review?

@amyeroberts @qubvel - as this relates to vision models and timm integration

Contributor

@qubvel qubvel left a comment

Hi @Factral, thanks for submitting the PR!

I would recommend using AutoModel + AutoProcessor to get features from any timm model; that works without adding preprocessor_config.json. Otherwise, we need to come up with a scheme to load the feature extractor the same way, without adding preprocessor_config.json to the repo on the Hub, because for timm models the preprocessing config is stored in config.json (please see how this was enabled for the other Auto* classes in the original PR #34564).
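As a concrete illustration of where the preprocessing values live for timm checkpoints, here is a small sketch. The config.json contents below are hand-written to mirror the layout of timm checkpoints on the Hub (a top-level "pretrained_cfg" entry), not fetched from the Hub:

```python
import json

# For timm checkpoints on the Hub, preprocessing values live inside
# config.json under "pretrained_cfg", so a feature extractor could be
# built from config.json alone, without a separate preprocessor_config.json.
# This example config is hand-written for illustration.

config_json = json.loads("""
{
  "architecture": "resnet18",
  "pretrained_cfg": {
    "input_size": [3, 224, 224],
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225]
  }
}
""")

pretrained_cfg = config_json["pretrained_cfg"]
print(pretrained_cfg["input_size"])  # (channels, height, width)
print(pretrained_cfg["mean"])        # normalization mean
print(pretrained_cfg["std"])         # normalization std
```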

@qubvel
Contributor

qubvel commented Jan 20, 2025

import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel

checkpoint = "timm/resnet18.a1_in1k"

model = AutoModel.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)

# load your image here
image = Image.new("RGB", (224, 224), (255, 0, 0))

inputs = processor(image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

for k, v in outputs.items():
    print(k, v.shape)

# last_hidden_state torch.Size([1, 512, 7, 7])
# pooler_output torch.Size([1, 512])

@qubvel qubvel added the Vision label Jan 21, 2025