Refactor swin transformer so later we can reuse components for the 3d version #6088
Conversation
… without head by specifying num_heads=None
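The headless-model idea can be sketched with a toy module (a hypothetical `TinyBackbone`, not torchvision's actual `SwinTransformer`, assuming the pattern the PR describes: passing `num_classes=None` skips building the classification head entirely):

```python
import torch
from torch import nn
from typing import Optional

class TinyBackbone(nn.Module):
    def __init__(self, embed_dim: int = 8, num_classes: Optional[int] = 10):
        super().__init__()
        self.features = nn.Linear(4, embed_dim)
        # With num_classes=None, no head is built and forward returns raw features.
        self.head = nn.Linear(embed_dim, num_classes) if num_classes is not None else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        if self.head is not None:
            x = self.head(x)
        return x
```

For example, `TinyBackbone(num_classes=None)(torch.randn(2, 4))` yields the `(2, 8)` feature tensor, while the default returns `(2, 10)` logits. (As the later discussion notes, the optional-`None` attribute needs care to stay jit-scriptable.)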
I noticed the failed test on jit / fx; I will investigate it.
@xiaohu2015 I wonder if you have the bandwidth to check the proposed updates in this PR. To provide some context, Yosua is making some small adjustments to the original implementation so that we can more easily support the 3D version of the Swin model. Please let us know if you have any thoughts.
…swin-transformer-breaking
…rap with torch.fx.wrap so it is excluded from tracing
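A minimal sketch of the `torch.fx.wrap` pattern that commit describes, using a toy helper (the name `_clamp_window` is hypothetical, not torchvision's code): wrapping a free function makes symbolic tracing record it as a single opaque `call_function` node instead of tracing into its data-dependent Python logic.

```python
import torch
import torch.fx

def _clamp_window(window_size, input_hw):
    # Data-dependent Python control flow that would otherwise
    # be baked into (or break) the symbolically traced graph.
    return [min(w, s) for w, s in zip(window_size, input_hw)]

# Exclude the helper from tracing; it shows up as one opaque node.
torch.fx.wrap("_clamp_window")

class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ws = _clamp_window([7, 7], [x.shape[-2], x.shape[-1]])
        # narrow() accepts the traced values as runtime ints.
        return x.narrow(-2, 0, ws[0]).narrow(-1, 0, ws[1])

gm = torch.fx.symbolic_trace(Toy())
```

At runtime the `GraphModule` calls the real Python function with concrete shapes, so the `min(...)` logic still executes per input.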
…mer-breaking [WIP] Part of refactoring swin transformer that requires re-training the model
Hi @xiaohu2015 , it would be great if you could take a look at this PR :)
We make the following changes:
- Use `List[int]` for params like `window_size`, `shift_size`, and `patch_size`
- Make `PatchMerging` and `SwinTransformerBlock` usable in both 2d and 3d
- Separate `PatchEmbed` from the `SwinTransformer` class and enable the user to change it via a constructor param
- Make `num_classes` optional, and enable the user to get the model without a head by setting it to `None`
- Update the method to handle edge cases where `window_size` is larger than the input
- Change the weight url to the ported weights (just ported, not retrained)
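The window-size edge-case handling can be sketched as a standalone helper (a hypothetical simplification, not torchvision's actual `_fix_window_and_shift_size`): when a window dimension is at least as large as the input dimension, clamp the window to the input size and disable the cyclic shift in that dimension.

```python
from typing import List, Tuple

def fix_window_and_shift_size(
    input_size: List[int], window_size: List[int], shift_size: List[int]
) -> Tuple[List[int], List[int]]:
    # Works for any number of dims, which is what lets the same
    # logic serve both the 2d and 3d (video) variants.
    new_window = list(window_size)
    new_shift = list(shift_size)
    for i in range(len(input_size)):
        if input_size[i] <= window_size[i]:
            # Window covers the whole input: shrink it and skip the shift.
            new_window[i] = input_size[i]
            new_shift[i] = 0
    return new_window, new_shift
```

For example, a 6x6 feature map with a 7x7 window and 3x3 shift becomes a 6x6 window with no shift, while a 56x56 map is left untouched.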
@@ -428,7 +506,7 @@ class Swin_T_Weights(WeightsEnum):
     "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
     "_metrics": {
         "ImageNet-1K": {
-            "acc@1": 81.474,
+            "acc@1": 81.470,
The acc@1 of swin_t decreased by 0.004%, but I think this is due to randomness, as the acc@5 for swin_s and the acc@1 for swin_b each increased by 0.002%.
I don't think you need to change this. Every time you run the analysis you might get a slightly different number on the 3rd decimal and there is not much you can do about it.
@YosuaMichael All your modifications look good to me. But I have one question: the two ways may be different and could have an impact on performance (this doesn't affect the current model, since we don't have such a situation).
Hi @xiaohu2015 , thanks for taking a look :)
You are right, I also think the second way is more reasonable (for a 192x192 input, we should use window_size=6 in the last stage in SwinV2).
@YosuaMichael I think you may want to implement https://arxiv.org/pdf/2106.13230.pdf? It should use 3D windows, or maybe I missed something?
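The window-size arithmetic behind that remark: Swin embeds with a patch size of 4 and then halves the resolution at each of the three patch-merging stages, a 32x total reduction by the last stage (patch size and stage count here follow the standard Swin configuration).

```python
def last_stage_resolution(input_size: int, patch_size: int = 4, num_merges: int = 3) -> int:
    # Total downsampling factor = patch_size * 2**num_merges (4 * 8 = 32 for Swin).
    return input_size // (patch_size * 2 ** num_merges)
```

So a 224x224 input gives a 7x7 last-stage feature map, where a window of 7 fits exactly, while a 192x192 input gives 6x6, where a window of 7 must be clamped to 6.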
@xiaohu2015 Yes, you are right! I want to implement |
Thanks @YosuaMichael, just a few comments below
…chael/vision into models/refactor-swin-transformer
LGTM, thanks @YosuaMichael.
Can you please provide proof from new inference runs to confirm the accuracy of the models remains unaffected?
We can merge on Green CI.
@datumbox I have rerun the validation script and it returned the exact same results as before (the ones in the PR description). Will merge now.
Hey @YosuaMichael! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py
…on (pytorch#6088)

* Use List[int] instead of int for window_size and shift_size
* Make PatchMerging and SwinTransformerBlock able to handle 2d and 3d cases
* Separate patch embedding from SwinTransformer and enable to get model without head by specifying num_heads=None
* Dont use if before padding so it is fx friendly
* Put the handling on window_size edge cases on separate function and wrap with torch.fx.wrap so it is excluded from tracing
* Update the weight url to the converted weight with new structure
* Update the accuracy of swin_transformer
* Change assert to Exception and nit
* Make num_classes optional
* Add typing output for _fix_window_and_shift_size function
* init head to None to make it jit scriptable
* Revert the change to make num_classes optional
* Revert unneccesarry changes that might be risky
* Remove self.head declaration
Why did you revert this modification?
…on (#6088) (#6100)
… 3d version (#6088)

Reviewed By: NicolasHug

Differential Revision: D36760917

fbshipit-source-id: 920177e069913775773c45e19abeb32017faaaee
It may be important for the swin_v2 model: since it uses a 192x192 input, the window_size in the last stage will be 6 rather than 7.
@xiaohu2015 Thanks for the feedback, and sorry we missed your message! @YosuaMichael could you please provide the details?
@xiaohu2015 Sorry for the late reply, and thanks for the feedback. The reason we reverted the changes is that we wanted to cherry-pick them into the torchvision 0.13.0 release, and at that time we decided not to change any previous behaviour, only to make "safer" changes aimed at reusing components for the 3d swin_transformer. I think we can put it back now with another PR; I will create one and tag you!
@xiaohu2015 I created the changes in this PR: #6222
Refactoring swin_transformer so we can reuse some components for video swin transformer.
Discussion is in: facebookresearch/multimodal#43 (comment)
In this change, we need to convert the previous weights to adapt to the new structure. After converting, I tested all the weights; here are the validation script and results:
Overall there are some minor changes (±0.004%) that I think are still acceptable.