
Refactor swin transformer so later we can reuse components for 3d version #6088

Merged

Conversation

YosuaMichael
Contributor

@YosuaMichael YosuaMichael commented May 25, 2022

Refactoring swin_transformer so we can reuse some components for the video swin transformer.
Discussion is in: facebookresearch/multimodal#43 (comment)

In this change, we need to convert the previous weights to fit the new structure. After converting, I tested all the weights; here are the validation script and results:

python -u ~/script/run_with_submitit.py \
    --timeout 3000 --nodes 1 --ngpus 1 --batch-size=1 \
    --partition train --model swin_s \
    --data-path="/datasets01_ontap/imagenet_full_size/061417" \
    --weights="Swin_S_Weights.IMAGENET1K_V1" \
    --test-only 

# Test:  Acc@1 81.470 Acc@5 95.776
# Previous accuracy before refactoring: Acc@1 81.474 Acc@5 95.776

python ~/script/run_with_submitit.py \
    --timeout 3000 --nodes 1 --ngpus 1 --batch-size=1 \
    --partition train --model swin_t \
    --data-path="/datasets01_ontap/imagenet_full_size/061417" \
    --weights="Swin_T_Weights.IMAGENET1K_V1" \
    --test-only 

# Test:  Acc@1 83.196 Acc@5 96.362
# Previous accuracy before refactoring: Acc@1 83.196 Acc@5 96.360

python -u ~/script/run_with_submitit.py \
    --timeout 3000 --nodes 1 --ngpus 1 --batch-size=1 \
    --partition train --model swin_b \
    --data-path="/datasets01_ontap/imagenet_full_size/061417" \
    --weights="Swin_B_Weights.IMAGENET1K_V1" \
    --test-only 

# Test:  Acc@1 83.584 Acc@5 96.640
# Previous accuracy before refactoring: Acc@1 83.582 Acc@5 96.640

Overall there are some minor changes (±0.004%) that I think are still acceptable.

@YosuaMichael
Contributor Author

I noticed the failing jit / fx tests; I will investigate.

@datumbox
Contributor

@xiaohu2015 I wonder if you have the bandwidth to check the proposed updates in this PR.

To provide some context, Yosua is making some small adjustments to the original implementation so that we can more easily support the 3D version of the Swin model. Please let us know if you have any thoughts.

@YosuaMichael YosuaMichael marked this pull request as ready for review May 25, 2022 22:15
@YosuaMichael
Contributor Author

> @xiaohu2015 I wonder if you have the bandwidth to check the proposed updates in this PR.
>
> To provide some context, Yosua is making some small adjustments to the original implementation so that we can more easily support the 3D version of the Swin model. Please let us know if you have any thoughts.

Hi @xiaohu2015, it would be great if you could take a look at this PR :)
To add more info: I have added the 3D version of the Swin Transformer model to torchmultimodal: https://github.com/facebookresearch/multimodal/blob/44786d8743d75fdd30ba79adc283a5eaaa3ecfca/torchmultimodal/modules/encoders/swin_transformer_3d_encoder.py
The plan is to add it to torchvision later; this PR refactors the torchvision 2d swin transformer so we can reuse some of the components and generally make it easier to upstream the 3D version.
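To illustrate the kind of reuse the refactoring aims for, here is a minimal, hypothetical sketch (not the torchvision implementation; the function name and numpy usage are illustrative assumptions) of a window partition that works for any number of spatial dimensions, so the same helper can serve both the 2d image model and the 3d video model:

```python
import numpy as np

def window_partition(x: np.ndarray, window_size):
    """Split an (*spatial, C) array into (num_windows, prod(window_size), C).

    Works for 2d (H, W, C) and 3d (T, H, W, C) inputs alike, which is the
    kind of reuse a List[int] window_size enables. Assumes each spatial dim
    is divisible by the corresponding window size.
    """
    spatial, C = x.shape[:-1], x.shape[-1]
    n = len(window_size)
    # Reshape each spatial dim s into (s // w, w): block axis then window axis.
    shape = []
    for s, w in zip(spatial, window_size):
        shape += [s // w, w]
    x = x.reshape(shape + [C])
    # Bring all block axes first, then all window axes, then channels.
    block_axes = list(range(0, 2 * n, 2))
    win_axes = list(range(1, 2 * n, 2))
    x = np.transpose(x, block_axes + win_axes + [2 * n])
    num_windows = int(np.prod([s // w for s, w in zip(spatial, window_size)]))
    return x.reshape(num_windows, int(np.prod(window_size)), C)
```

The same call then handles `window_size=[7, 7]` for images and `window_size=[2, 7, 7]` for video clips.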

@YosuaMichael YosuaMichael changed the title [WIP] Refactor swin transfomer so later we can reuse component for 3d version Refactor swin transfomer so later we can reuse component for 3d version May 25, 2022

@YosuaMichael YosuaMichael left a comment


We make the following changes:

  • Use List[int] for params like window_size, shift_size and patch_size
  • Make PatchMerging and SwinTransformerBlock usable in both the 2d and 3d cases
  • Separate PatchEmbed from the SwinTransformer class and let the user change it via a constructor param
  • Make num_classes optional, so the user can get the model without a head by setting it to None
  • Update the method that handles the edge case where window_size is larger than the input
  • Change the weight url to the ported weights (just ported, not retrained)
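The edge-case bullet above can be sketched roughly as follows (a hypothetical helper, not the exact torchvision code): when the input is smaller than the window along some dimension, the window shrinks to the input size and shifting is disabled along that dimension, since a shift is a no-op when a single window covers the whole axis.

```python
from typing import List, Tuple

def fix_window_and_shift_size(
    input_size: List[int], window_size: List[int], shift_size: List[int]
) -> Tuple[List[int], List[int]]:
    """Shrink window_size to input_size where needed and zero the shift there."""
    window_size = list(window_size)
    shift_size = list(shift_size)
    for i in range(len(input_size)):
        if input_size[i] <= window_size[i]:
            window_size[i] = input_size[i]  # one window covers the whole axis
            shift_size[i] = 0               # shifting would be a no-op, disable it
    return window_size, shift_size
```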

@@ -428,7 +506,7 @@ class Swin_T_Weights(WeightsEnum):
     "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
     "_metrics": {
         "ImageNet-1K": {
-            "acc@1": 81.474,
+            "acc@1": 81.470,
Contributor Author


The acc@1 of swin_t decreased by 0.004%, but I think this is due to randomness, as the acc@5 for swin_s and the acc@1 for swin_b each increased by 0.002%.

Contributor


I don't think you need to change this. Every time you run the analysis you might get a slightly different number on the 3rd decimal and there is not much you can do about it.

@xiaohu2015
Contributor

xiaohu2015 commented May 26, 2022

@YosuaMichael All your modifications look good to me. But I have one question about "Update the method to handle edge cases where window_size is larger than input".
In the original code, we pad the input and set shift_size=0 but do not change the window size. You chose to change the window size and not pad the input.

The two approaches may differ and could have an impact on performance (though this doesn't affect the current models, since we don't hit that situation).
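The two approaches discussed here can be contrasted with a small, hypothetical sketch (neither function is from the actual codebases): the original repo pads the input up to the window size and disables the shift, while this PR shrinks the window to the input instead.

```python
from typing import List, Tuple

def pad_strategy(input_size: List[int], window_size: List[int]) -> Tuple[List[int], List[int]]:
    # Pad each spatial dim up to a multiple of the window; shift becomes 0.
    padded = [((s + w - 1) // w) * w for s, w in zip(input_size, window_size)]
    return padded, [0] * len(input_size)

def shrink_strategy(input_size: List[int], window_size: List[int]) -> List[int]:
    # Shrink the window to the input wherever the window is larger.
    return [min(s, w) for s, w in zip(input_size, window_size)]
```

For a 6x6 input with a 7x7 window, padding gives a 7x7 input with one window that partly attends over padded tokens, while shrinking gives a single 6x6 window over the real tokens only.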

@YosuaMichael
Contributor Author

> @YosuaMichael All your modifications look good to me. But I have one question about "Update the method to handle edge cases where window_size is larger than input".
> In the original code, we pad the input and set shift_size=0 but do not change the window size. You chose to change the window size and not pad the input.

Hi @xiaohu2015, thanks for taking a look :)
I think the original code also modifies the window_size when it is larger than the input size. Here is the part I refer to: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L192-L195

@xiaohu2015
Contributor

xiaohu2015 commented May 26, 2022

> @YosuaMichael All your modifications look good to me. But I have one question about "Update the method to handle edge cases where window_size is larger than input".
> In the original code, we pad the input and set shift_size=0 but do not change the window size. You chose to change the window size and not pad the input.
>
> Hi @xiaohu2015, thanks for taking a look :) I think the original code also modifies the window_size when it is larger than the input size. Here is the part I refer to: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L192-L195

You are right, and I also think the second way is more reasonable (for a 192x192 input, we should use window_size=6 in the last stage in swinv2).

@xiaohu2015
Contributor

@YosuaMichael I think you may want to implement https://arxiv.org/pdf/2106.13230.pdf? It should use 3D windows? Maybe I missed something?

@YosuaMichael
Contributor Author

YosuaMichael commented May 26, 2022

> @YosuaMichael I think you may want to implement https://arxiv.org/pdf/2106.13230.pdf? It should use 3D windows? Maybe I missed something?

@xiaohu2015 Yes, you are right! I want to implement the Video Swin Transformer. Currently I have added it to the torchmultimodal repository first (see this PR), but I plan to upstream it into torchvision as a video classifier after that.

Contributor

@datumbox datumbox left a comment


Thanks @YosuaMichael, just a few comments below

Contributor

@datumbox datumbox left a comment


LGTM, thanks @YosuaMichael.

Can you please provide new inference runs to confirm that the accuracy of the models remains unaffected?

We can merge on Green CI.

@YosuaMichael
Contributor Author

YosuaMichael commented May 26, 2022

@datumbox I have rerun the validation script and it returned exactly the same results as before (the ones in the PR description). Will merge now.

@YosuaMichael YosuaMichael merged commit 952f480 into pytorch:main May 26, 2022
@github-actions

Hey @YosuaMichael!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

YosuaMichael added a commit to YosuaMichael/vision that referenced this pull request May 26, 2022
…on (pytorch#6088)

* Use List[int] instead of int for window_size and shift_size

* Make PatchMerging and SwinTransformerBlock able to handle 2d and 3d cases

* Separate patch embedding from SwinTransformer and enable getting the model without a head by specifying num_classes=None

* Don't use if before padding so it is fx friendly

* Put the handling of window_size edge cases in a separate function and wrap it with torch.fx.wrap so it is excluded from tracing

* Update the weight url to the converted weight with new structure

* Update the accuracy of swin_transformer

* Change assert to Exception and nit

* Make num_classes optional

* Add typing output for _fix_window_and_shift_size function

* init head to None to make it jit scriptable

* Revert the change to make num_classes optional

* Revert unnecessary changes that might be risky

* Remove self.head declaration
@xiaohu2015
Contributor

> @YosuaMichael All your modifications look good to me. But I have one question about "Update the method to handle edge cases where window_size is larger than input".
> In the original code, we pad the input and set shift_size=0 but do not change the window size. You chose to change the window size and not pad the input.
>
> Hi @xiaohu2015, thanks for taking a look :) I think the original code also modifies the window_size when it is larger than the input size. Here is the part I refer to: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L192-L195

Why did you revert this modification?

YosuaMichael added a commit that referenced this pull request May 31, 2022
…on (#6088) (#6100)

facebook-github-bot pushed a commit that referenced this pull request Jun 1, 2022
… 3d version (#6088)


Reviewed By: NicolasHug

Differential Revision: D36760917

fbshipit-source-id: 920177e069913775773c45e19abeb32017faaaee
@xiaohu2015
Contributor

> @YosuaMichael All your modifications look good to me. But I have one question about "Update the method to handle edge cases where window_size is larger than input".
> In the original code, we pad the input and set shift_size=0 but do not change the window size. You chose to change the window size and not pad the input.
>
> Hi @xiaohu2015, thanks for taking a look :) I think the original code also modifies the window_size when it is larger than the input size. Here is the part I refer to: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L192-L195
>
> Why did you revert this modification?

It may be important for the swin_v2 models: since they use a 192x192 input, the window_size in the last stage will be 6 rather than 7.
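The arithmetic behind that observation can be checked directly (assuming the standard Swin configuration of patch_size=4 followed by three PatchMerging stages, each halving the resolution):

```python
# For a 192x192 input: patches of 4 give a 48x48 token grid, and three
# PatchMerging steps halve the grid each time: 48 -> 24 -> 12 -> 6.
input_size = 192
patch_size = 4
num_merges = 3
last_stage = input_size // patch_size // (2 ** num_merges)
print(last_stage)  # 6, so a 7x7 window must shrink to 6x6 in the last stage
```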

@datumbox
Contributor

datumbox commented Jun 30, 2022

@xiaohu2015 Thanks for the feedback. Sorry we missed your message!

@YosuaMichael could you please provide the details?

@YosuaMichael
Contributor Author

@xiaohu2015 Sorry for the late reply, and thanks for the feedback.

The reason we reverted the change is that we wanted to cherry-pick this PR into the new torchvision 0.13.0 release; at that time we decided not to change any previous behaviour and only make "safer" changes aimed at reusing components for the 3d swin transformer.

I think we can put it back now with another PR; I will create one and tag you!

@YosuaMichael
Contributor Author

@xiaohu2015 I create the changes on this PR: #6222
