🚨 Fix torch.jit.trace for interpolate_pos_encoding in all vision models #33226
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Failing tests seem unrelated (beam search; nothing to do with this PR). Edit: For 100% backwards compatibility, perhaps we should add some "legacy" config value, and default to the previous version (with
amyeroberts left a comment
Thanks for opening this PR, the detailed write up and fixing this for all of these models!
Overall the change looks good and it's great to have this finally fixed ❤️
Two comments:
- I'm not sure about the abstraction of the logic into a `modeling_vision_utils.py` module. "modeling_vision_utils" isn't a very well-defined utility module. Utils have a tendency to gather code-dust, and the line between what is a vision utility versus a common pattern for modeling isn't clear-cut, unlike other utilities e.g. attention masks or RoPE, which are more model-independent and where it's clear what type of objects belong in that file. It's more in line with transformers to update the logic in all of these models using `# Copied from` statements.
- We shouldn't remove the docstrings from the public methods
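For context, a `# Copied from` statement in transformers is a comment consumed by the repo's copy-consistency tooling (e.g. `make fix-copies`), which keeps duplicated functions in sync with a single source instead of sharing them through a utils module. A minimal illustrative sketch (the target path is an example, not taken from this PR):

```python
# Illustrative sketch of the "# Copied from" convention used in transformers.
# The comment below is read by the repo's consistency tooling, which keeps
# this function identical to the referenced source across model files.

# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
def interpolate_pos_encoding(self, embeddings, height, width):
    # ... body kept in sync with the referenced ViT implementation ...
    raise NotImplementedError
```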
In this case, as not using
@xenova Hi, I'm highly interested in converting and also double-checking the ONNX format inference in
Sure I can do that! I will choose dino as the source for when
Agreed - will add back! @SangbumChoi Using huggingface/optimum#2001: `optimum-cli export onnx --model MODEL_ID_GOES_HERE` should work
Addressed your comments @amyeroberts, and tests pass now (the previous failure was flaky)
amyeroberts left a comment
Beautiful - thanks for fixing this!
Added 🚨 to PR! Merging!
Hi @xenova, thanks for making this PR. Would it be possible to also add this fix for the models added in #29261 (i.e.
…models (huggingface#33226)
* Fix `torch.jit.tracing` for `interpolate_pos_encoding` in all vision models
* Apply formatting
* Add missing `self.config = config`
* Fix copies
* Fix hiera interpolation unit test
* Formatting
* Update `_import_structure`
* make style
* Fix docstring
* Use `# Copied from` instead of utils
* DeiT variable renaming (`class_and_dist_pos_embed`)
* Fix Hiera `interpolate_pos_encoding`
What does this PR do?
A much needed overhaul of `interpolate_pos_encoding` to remove python type casts and use the `size` variant of `torch.nn.functional.interpolate` instead of `scale_factor`, which was error prone. This is done by abstracting the function into a separate vision utils file to account for class embeddings, upcasting before interpolation, etc. The option to copy this function across files was there, but this is a lot cleaner (imo) and will prevent such issues in future.

This PR has the following benefits:
* Fixes `torch.jit.trace` to support dynamic shapes, by avoiding python type casts and ensuring the correct branch is taken for interpolation. Among other things, this means these vision models can be exported to ONNX with dynamic shapes enabled.
* Removes the `+ 0.1` offset added to prevent precision issues (original issue). This was originally done in dinov1 and then corrected in dinov2 (here), which this implementation takes advantage of. The `+ 0.1` offset, combined with the `scale_factor`, produced slightly off-center values.
* Adds the missing `self.config = config` in `SwinEmbeddings`.

Fixes #33181 #32410
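To make the core change concrete, here is a simplified sketch (not the exact transformers implementation; the function name and shapes are illustrative) of a trace-friendly position-embedding interpolation: the target grid is computed from the input resolution and passed to `F.interpolate` via `size=`, rather than deriving a fudged `scale_factor` with a `+ 0.1` offset:

```python
# Simplified sketch of trace-friendly position-embedding interpolation.
# Assumptions (not from the PR itself): a square original patch grid and a
# single [CLS] token prepended to the patch embeddings.
import torch
import torch.nn.functional as F

def interpolate_pos_encoding(pos_embed, height, width, patch_size):
    # pos_embed: (1, 1 + num_patches, dim); first position is the [CLS] token
    num_positions = pos_embed.shape[1] - 1
    dim = pos_embed.shape[-1]
    class_pos_embed = pos_embed[:, :1]
    patch_pos_embed = pos_embed[:, 1:]

    # Target grid derived from the (possibly dynamic) input resolution
    new_h = height // patch_size
    new_w = width // patch_size
    grid = int(num_positions**0.5)  # side length of the original square grid

    patch_pos_embed = patch_pos_embed.reshape(1, grid, grid, dim).permute(0, 3, 1, 2)
    # size= gives the exact output shape; scale_factor= (with the old +0.1
    # fudge) could round to an off-by-one grid and mis-center the embeddings
    patch_pos_embed = F.interpolate(
        patch_pos_embed, size=(new_h, new_w), mode="bicubic", align_corners=False
    )
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, new_h * new_w, dim)
    return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
```

With `size=` the output grid is exact for any input resolution, which is what lets `torch.jit.trace` produce a graph that remains valid for dynamic shapes.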
Linked PRs in Optimum:
Overview of models
Who can review?
@amyeroberts @NielsRogge @merveenoyan @qubvel