Add TimmWrapper #34564
Conversation
LysandreJik
left a comment
This is starting to look nice!
>>> # Load model and image processor
>>> checkpoint = "timm/resnet50.a1_in1k"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
>>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()
Really nice that there are no kwargs or anything special needed to load the model
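For context, the snippet above might continue roughly like this (a sketch only; the image path is a placeholder and the forward-pass code is not part of the quoted diff):

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

checkpoint = "timm/resnet50.a1_in1k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()

# Preprocess an image and run a forward pass (the path is a placeholder)
image = Image.open("path/to/image.jpg")
inputs = image_processor(image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
predicted_id = logits.argmax(-1).item()
print("Predicted class id:", predicted_id)
```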
# Check format of the archive
with safe_open(checkpoint_file, framework="pt") as f:
    metadata = f.metadata()
if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
Ok with this
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
new_keys = []
renamed_keys = {}
renamed_gamma = {}
renamed_beta = {}
warning_msg = f"A pretrained model of type `{model_to_load.__class__.__name__}` "
for key in state_dict.keys():
    new_key = None
    if "gamma" in key:
        # We add only the first key as an example
        new_key = key.replace("gamma", "weight")
        renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
    if "beta" in key:
        # We add only the first key as an example
        new_key = key.replace("beta", "bias")
        renamed_beta[key] = new_key if not renamed_beta else renamed_beta
    if new_key:
        old_keys.append(key)
        new_keys.append(new_key)
renamed_keys = {**renamed_gamma, **renamed_beta}
if renamed_keys:
    warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
    for old_key, new_key in renamed_keys.items():
        warning_msg += f"* `{old_key}` -> `{new_key}`\n"
    warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
    logger.info_once(warning_msg)
for old_key, new_key in zip(old_keys, new_keys):
    state_dict[new_key] = state_dict.pop(old_key)
This should eventually be completely removed, cc @ArthurZucker
if metadata is None:
    pass
Can you please add a comment saying that, in case of no metadata, it's treated as a PyTorch checkpoint?
Added 7e0d2c6
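For reference, the idea is roughly the following (a sketch only; the exact comment wording in 7e0d2c6 may differ):

```python
if metadata is None:
    # No format metadata in the archive: assume it is a plain PyTorch checkpoint
    # (e.g. safetensors files not produced by `save_pretrained`, such as timm checkpoints).
    pass
```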
default_image_processor_filename = (
    "config.json" if is_timm_checkpoint(pretrained_model_name_or_path) else IMAGE_PROCESSOR_NAME
)
kwargs["image_processor_filename"] = kwargs.get("image_processor_filename", default_image_processor_filename)
You can use CONFIG_NAME instead
src/transformers/utils/generic.py
Outdated
num_items_in_batch: Optional[int]


def is_timm_hub_checkpoint(pretrained_model_name_or_path: str) -> bool:
Given the method's objective, I would have it accept only a pretrained_model_name
function removed with ff6efde (see comment below)
src/transformers/utils/generic.py
Outdated
if os.path.isfile(pretrained_model_name_or_path) or os.path.isdir(pretrained_model_name_or_path):
    return False
and I'd therefore remove this
function removed with ff6efde (see comment below)
src/transformers/utils/generic.py
Outdated
return pretrained_model_name_or_path.startswith("hf-hub:timm/") or pretrained_model_name_or_path.startswith(
    "timm/"
)
Good as long as we don't expect community checkpoints, but there are community checkpoints already, for example see the following: https://huggingface.co/prov-gigapath/prov-gigapath
I think we'll need a more robust check here
Indeed, this is a weak assumption. This function is only needed in the image processors auto class to load it from the config. In AutoImageProcessor.from_pretrained the only information we have is the model name, so the only way to make a robust check is to load files from the Hub.
I removed this function entirely and instead made a fallback for loading the timm image processor dict in the auto class of the image processor. To avoid loading the config for every model, I did it in the following way:
- Try to load the image processor config as usual - most of the models will be fine, and we won't have any overhead here.
- In case of an exception, try loading config.json and check if it's a timm checkpoint.
See ff6efde for details. I documented it in the code, let me know if you have doubts about this approach.
Works with:
image_processor = AutoImageProcessor.from_pretrained("prov-gigapath/prov-gigapath")
This looks better to me!
def test_model_is_small(self):
    pass


# Overriding as output_attentions is not supported by TimmWrapper
This should be removed, no?
Thanks, removed in a476610
@require_torch
@require_vision
class TimmWrapperModelIntegrationTest(unittest.TestCase):
Should require timm as well
Added in 327095a
Thanks, this looks good! cc @molbap can you give the processor code a quick look just to double check?
molbap
left a comment
Hey, took a quick look at the processor and ran it, found some stuff which I commented on! Also looked at the whole PR, really nice work!
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", "")
image_processor_filename = kwargs.pop("image_processor_filename", IMAGE_PROCESSOR_NAME)
I understand - since it's a new kwarg that does not seem to have an equivalent in hub methods (like use_auth_token or revision), I'd add a small docstring to advertise it.
try:
    # Main path for all transformers models and local TimmWrapper checkpoints
    config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
        pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
    )
except Exception as initial_exception:
    # Fallback path for Hub TimmWrapper checkpoints. Timm models' image processing is saved in `config.json`
    # instead of `preprocessor_config.json`. Because this is an Auto class and we don't have any information
    # except the model name, the only way to check if a remote checkpoint is a timm model is to try to
    # load `config.json` and if it fails with some error, we raise the initial exception.
    try:
        config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
            pretrained_model_name_or_path, image_processor_filename=CONFIG_NAME, **kwargs
        )
    except Exception:
        raise initial_exception
A remark here - that's true for any processor with a preprocessor_config.json, but would it make sense to sanitize the inputs a bit? On community checkpoints there are extra keys that are unused, for instance model_args, which contains duplicated information.
(I'm asking because we're already using a try/catch pattern here, so that allows some branching.)
Not sure I got it, can you provide more details, please?
basically if I load https://huggingface.co/prov-gigapath/prov-gigapath/blob/main/config.json, I'll end up with an ImageProcessor object that has keys I can't make much use of in transformers, such as model_args, so I wondered if it made sense to filter the contents of that config.json to an expected schema
Ok, it makes sense now. I implemented it at the image-processor init level, see fd7b646.
Is this what you had in mind?
>>> from transformers import AutoImageProcessor, AutoConfig
>>> image_processor = AutoImageProcessor.from_pretrained("prov-gigapath/prov-gigapath")
>>> print(image_processor)
TimmWrapperImageProcessor {
"architecture": "vit_giant_patch14_dinov2",
"data_config": {
"crop_mode": "center",
"crop_pct": 1.0,
"input_size": [
3,
224,
224
],
"interpolation": "bicubic",
"mean": [
0.485,
0.456,
0.406
],
"std": [
0.229,
0.224,
0.225
]
},
"image_processor_type": "TimmWrapperImageProcessor"
}
yes, exactly! that looks good
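For anyone reading along, the filtering discussed above could look roughly like this (a minimal sketch; the key set and helper name are assumptions, not the actual code from fd7b646):

```python
# Keep only the keys the image processor understands; community checkpoints may
# carry extra entries such as `model_args` in their config.json.
EXPECTED_KEYS = {"architecture", "data_config", "image_processor_type"}

def sanitize_image_processor_dict(config_dict: dict) -> dict:
    return {key: value for key, value in config_dict.items() if key in EXPECTED_KEYS}
```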
if isinstance(images, torch.Tensor):
    images = self.val_transforms(images)
    # Add batch dimension if a single image
    images = images.unsqueeze(0) if images.ndim == 3 else images
here, if val_transforms is for instance
Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)
then the ToTensor op will fail, since F.to_tensor does expect a PIL image IIRC
Nice catch! Since timm>=1.0.8 it's MaybeToTensor and works fine, but on timm<1.0.8 it indeed raises an error:
TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>
I can add something like
if version.parse(timm.__version__) < version.parse("1.0.8") and isinstance(images, torch.Tensor):
    images = images.cpu().numpy()
Added in a10bc0d
perfect!
@qubvel version checks tend to be brittle, can you do a hasattr on the MaybeToTensor class existing? I think that should line up with its use in the transforms?
@rwightman, thanks for the note. I added a check for the class name; it doesn't look that elegant, but I hope it's more robust: 3d1a76e
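A rough sketch of that idea (the helper name and exact structure are assumptions, not the code from 3d1a76e): inspect the composed transforms and only convert a tensor input back to a numpy array when a plain ToTensor, which expects a PIL image or ndarray, is present.

```python
import torch

def _maybe_to_numpy(images, val_transforms):
    # Older timm versions use torchvision's ToTensor, which rejects torch.Tensor
    # inputs; newer versions use MaybeToTensor, which passes tensors through.
    transform_names = {transform.__class__.__name__ for transform in val_transforms.transforms}
    if "ToTensor" in transform_names and isinstance(images, torch.Tensor):
        images = images.cpu().numpy()
    return images
```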
@LysandreJik @molbap @rwightman Thanks for the reviews! I believe all comments have been addressed. Do you have anything else in mind? It would be nice to move it forward.

Ok awesome! At this point just a second quick look from @ArthurZucker and we're good
ArthurZucker
left a comment
Super nice! This is kind of a perfect integration with the Auto API, congrats!
normalize = (
    Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
    if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
    else Lambda(lambda x: x)
)
let's do something explicit for this!
I'm not sure what you meant here, but I made it a bit more readable in 4edfe90 IMO. It's actually not related to the timm wrapper, it's the same in the original code.
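For illustration, a more explicit variant could build the transform list conditionally instead of falling back to an identity Lambda (a sketch under assumed names; not necessarily what 4edfe90 does):

```python
from torchvision.transforms import Compose, Normalize, ToTensor

def build_transform(image_processor):
    # Start from the tensor conversion and only append Normalize when the
    # processor actually provides the statistics.
    transforms = [ToTensor()]
    if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std"):
        transforms.append(Normalize(mean=image_processor.image_mean, std=image_processor.image_std))
    return Compose(transforms)
```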
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
new_keys = []
renamed_keys = {}
renamed_gamma = {}
renamed_beta = {}
warning_msg = f"A pretrained model of type `{model_to_load.__class__.__name__}` "
for key in state_dict.keys():
    new_key = None
    if "gamma" in key:
        # We add only the first key as an example
        new_key = key.replace("gamma", "weight")
        renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
    if "beta" in key:
        # We add only the first key as an example
        new_key = key.replace("beta", "bias")
        renamed_beta[key] = new_key if not renamed_beta else renamed_beta
    if new_key:
        old_keys.append(key)
        new_keys.append(new_key)
renamed_keys = {**renamed_gamma, **renamed_beta}
if renamed_keys:
    warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
    for old_key, new_key in renamed_keys.items():
        warning_msg += f"* `{old_key}` -> `{new_key}`\n"
    warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
    logger.info_once(warning_msg)
for old_key, new_key in zip(old_keys, new_keys):
    state_dict[new_key] = state_dict.pop(old_key)
Yep, there is #33192, which is kind of related.
* Add files
* Init
* Add TimmWrapperModel
* Fix up
* Some fixes
* Fix up
* Remove old file
* Sort out import orders
* Fix some model loading
* Compatible with pipeline and trainer
* Fix up
* Delete test_timm_model_1/config.json
* Remove accidentally commited files
* Delete src/transformers/models/modeling_timm_wrapper.py
* Remove empty imports; fix transformations applied
* Tidy up
* Add image classifcation model to special cases
* Create pretrained model; enable device_map='auto'
* Enable most tests; fix init order
* Sort imports
* [run-slow] timm_wrapper
* Pass num_classes into timm.create_model
* Remove train transforms from image processor
* Update timm creation with pretrained=False
* Fix gamma/beta issue for timm models
* Fixing gamma and beta renaming for timm models
* Simplify config and model creation
* Remove attn_implementation diff
* Fixup
* Docstrings
* Fix warning msg text according to test case
* Fix device_map auto
* Set dtype and device for pixel_values in forward
* Enable output hidden states
* Enable tests for hidden_states and model parallel
* Remove default scriptable arg
* Refactor inner model
* Update timm version
* Fix _find_mismatched_keys function
* Change inheritance for Classification model (fix weights loading with device_map)
* Minor bugfix
* Disable save pretrained for image processor
* Rename hook method for loaded keys correction
* Rename state dict keys on save, remove `timm_model` prefix, make checkpoint compatible with `timm`
* Managing num_labels <-> num_classes attributes
* Enable loading checkpoints in Trainer to resume training
* Update error message for output_hidden_states
* Add output hidden states test
* Decouple base and classification models
* Add more test cases
* Add save-load-to-timm test
* Fix test name
* Fixup
* Add do_pooling
* Add test for do_pooling
* Fix doc
* Add tests for TimmWrapperModel
* Add validation for `num_classes=0` in timm config + test for DINO checkpoint
* Adjust atol for test
* Fix docs
* dev-ci
* dev-ci
* Add tests for image processor
* Update docs
* Update init to new format
* Update docs in configuration
* Fix some docs in image processor
* Improve docs for modeling
* fix for is_timm_checkpoint
* Update code examples
* Fix header
* Fix typehint
* Increase tolerance a bit
* Fix Path
* Fixing model parallel tests
* Disable "parallel" tests
* Add comment for metadata
* Refactor AutoImageProcessor for timm wrapper loading
* Remove custom test_model_outputs_equivalence
* Add require_timm decorator
* Fix comment
* Make image processor work with older timm versions and tensor input
* Save config instead of whole model in image processor tests
* Add docstring for `image_processor_filename`
* Sanitize kwargs for timm image processor
* Fix doc style
* Update check for tensor input
* Update normalize
* Remove _load_timm_model function

---------

Co-authored-by: Amy Roberts <[email protected]>
subfolder (`str`, *optional*, defaults to `""`):
    In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
    specify the folder name here.
image_processor_filename (`str`, *optional*, defaults to `"config.json"`):
Shouldn't it default to "preprocessor_config.json"?
@hongwhatamazon, yes, you are right! Feel free to propose a PR to fix if you have bandwidth
What does this PR do?
Adds a TimmWrapper set of classes so that timm models can be loaded into the library as transformers models.
Continuation of
General Usage
Pipeline
Timm models can now be used in the image classification (if a classification model) and image feature extraction pipelines
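For example, a classification checkpoint can be used directly (a minimal sketch; the checkpoint name is the one used in the examples above and the image path is a placeholder):

```python
from transformers import pipeline

# Image classification pipeline backed by a timm checkpoint
classifier = pipeline("image-classification", model="timm/resnet50.a1_in1k")
predictions = classifier("path/to/image.jpg")  # placeholder path
print(predictions)
```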
Trainer
Timm models can now be loaded and trained with the trainer class.
Example model trained with the trainer running the script command below:
https://huggingface.co/qubvel-hf/vit-base-beans
Other features enabled
- `output_hidden_states=True` or `output_hidden_states=[1, 2, 3]` (to select specific hidden states)
TODO
- `output_hidden_states` tests
- `transformers` instead of `timm`, which architectures are affected?