Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SwinV2 #6246

Merged
merged 61 commits into from
Aug 10, 2022
Merged

Add SwinV2 #6246

merged 61 commits into from
Aug 10, 2022

Conversation

ain-soph
Copy link
Contributor

@ain-soph ain-soph commented Jul 7, 2022

Fixes #6242

@ain-soph
Copy link
Contributor Author

ain-soph commented Jul 7, 2022

And I see an issue from the Microsoft SwinTransformer repo: microsoft/Swin-Transformer#194

It thinks it's not necessary to divide the mask into 9 parts because 4 parts are already enough.
I kind of agree with that. Anyone has opinion on that?

@ain-soph
Copy link
Contributor Author

ain-soph commented Jul 8, 2022

It seems these mod operations are considered as nodes of the fx graph, and their output is int rather than torch.Tensor

Previous SwinTransformer V1 doesn't fail at these tests only because those mod operations are not sampled with that random seed.

image

While this test file requires it to be torch.Tensor
image

@datumbox
Copy link
Contributor

datumbox commented Jul 8, 2022

@ain-soph Thanks a lot for your contribution!

Concerning the linter, if you check the CI on tab Required lint modifications it will show you what's the problem:
image

Concerning your point the PatchMerging, I believe you are right. FX doesn't let you use the input of the tensor to control the flow of the program. This needs to move outside of the main function and be declared a non fx-traceable operator. I would patch this ASAP outside of this PR. cc @YosuaMichael

Copy link
Contributor

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ain-soph The PR looks in the right direction. I've added a few comments for your attention, please have a look.

Note that I'll be on PTO the next 2 weeks so @YosuaMichael offered to provide assistance. He is also a Meta maintainer so you can schedule a call if needed to speed this up.

Please note that my comments don't address the validity of the changes you do on the ML side but mostly on the idioms and coding practices we use at TorchVision. A more deep dive check would be necessary to confirm validity. The first step I think is to take the original pre-trained weights from the paper and confirm that we can load them in your implementation (by making some conversion) and then reproduce the reported accuracy using our reference scripts. Once we verify this, the next step is to train the tiny variant to show-case we can reproduce the accuracy. If you don't have access to a GPU cluster, we can help.

Again many thanks for your contribution and help on this.

cc @xiaohu2015 and @jdsgomes who worked on the original Swin implementation to see if they have any early feedback.

torchvision/models/swin_transformer.py Outdated Show resolved Hide resolved
torchvision/models/swin_transformer.py Outdated Show resolved Hide resolved
torchvision/models/swin_transformer.py Outdated Show resolved Hide resolved
torchvision/models/swin_transformer.py Outdated Show resolved Hide resolved
torchvision/models/swin_transformer.py Outdated Show resolved Hide resolved
torchvision/models/swin_transformer.py Show resolved Hide resolved
torchvision/models/swin_transformer.py Outdated Show resolved Hide resolved
@YosuaMichael
Copy link
Contributor

Hi @ain-soph, thanks a lot for the PR!
As of now, I am still reading SwinTransformerV2 paper and original code, and I will try to review afterwards.

Meanwhile, let me address some of the issue you raise:

  1. On ufmt issue, can you make sure you install the following version: pip install ufmt==1.3.2 black==21.9b0 usort==0.6.4 (reference)
  2. For the fx issue, I create a small patch: Small Patch SwinTransformer for FX compatibility #6252

Copy link
Contributor

@YosuaMichael YosuaMichael left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @ain-soph , overall the PR looks good!
I agree with @datumbox comments that we may want to refactor some code so it can be reused in both V1 and V2 and I add some suggestion here.

torchvision/models/swin_transformer.py Outdated Show resolved Hide resolved
torchvision/models/swin_transformer.py Outdated Show resolved Hide resolved
torchvision/models/swin_transformer.py Outdated Show resolved Hide resolved
torchvision/models/swin_transformer.py Outdated Show resolved Hide resolved
self.register_buffer("relative_position_index", relative_position_index)

def get_relative_position_bias(self) -> torch.Tensor:
relative_position_bias_table: torch.Tensor = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
Copy link
Contributor Author

@ain-soph ain-soph Jul 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we use flatten(end_dim=-2) here instead of view(-1, self.num_heads)?

@ain-soph
Copy link
Contributor Author

ain-soph commented Aug 4, 2022

I've launched another train based on the most up-to-date code for tiny architecture.
Currently I don't observe any strange trend showing failure of convergence.

Epoch: [142]  [2500/2502]  eta: 0:00:00  lr: 0.0006042917608127198  img/s: 266.4080904318139  loss: 3.6911 (3.6364)  acc1: 48.4375 (48.2548)  acc5: 73.4375 (71.9340)  time: 0.4798  data: 0.0001  max mem: 14880
Epoch: [142] Total time: 0:20:10
Test:   [ 0/98]  eta: 0:06:25  loss: 1.5864 (1.5864)  acc1: 82.0312 (82.0312)  acc5: 96.0938 (96.0938)  time: 3.9314  data: 3.7073  max mem: 14880
Test:  Total time: 0:00:25
Test:  Acc@1 71.494 Acc@5 91.496
Test: EMA  [ 0/98]  eta: 0:05:43  loss: 1.4705 (1.4705)  acc1: 88.2812 (88.2812)  acc5: 97.6562 (97.6562)  time: 3.5061  data: 3.2585  max mem: 14880
Test: EMA Total time: 0:00:25
Test: EMA Acc@1 76.420 Acc@5 93.660 

@datumbox
Copy link
Contributor

datumbox commented Aug 4, 2022

@ain-soph Could you please share the exact command you use to train it to ensure we are not missing anything from our side?

@datumbox
Copy link
Contributor

datumbox commented Aug 4, 2022

Ooops, there are also memory issues on GPU both on linux and windows. I know that a different test is actually failing but this can happen. Sometimes adding a bing new model leads to issues on other models because of failing to clear the memory properly. Usually the way around this is to reduce the memory footprint of the test by passing thought smaller sizes for input or disabling the particular test on the GPU.

@ain-soph
Copy link
Contributor Author

ain-soph commented Aug 5, 2022

torchrun --nproc_per_node=4 train.py --model swin_v2_t --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0  --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear  --lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler --ra-reps 4 --val-resize-size 256 --val-crop-size 256 --train-crop-size 256
Epoch: [221] Total time: 0:25:25
Test:   [ 0/98]  eta: 0:06:23  loss: 1.3831 (1.3831)  acc1: 91.4062 (91.4062)  acc5: 98.4375 (98.4375)  time: 3.9118 data: 3.6845  max mem: 14880
Test:  Total time: 0:00:25
Test:  Acc@1 77.738 Acc@5 94.460
Test: EMA  [ 0/98]  eta: 0:05:15  loss: 1.3172 (1.3172)  acc1: 90.6250 (90.6250)  acc5: 99.2188 (99.2188)  time: 3.2230 data: 2.9882  max mem: 14880
Test: EMA Total time: 0:00:25
Test: EMA Acc@1 80.012 Acc@5 95.286

@jdsgomes
Copy link
Contributor

jdsgomes commented Aug 5, 2022

torchrun --nproc_per_node=4 train.py --model swin_v2_t --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0  --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear  --lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler --ra-reps 4 --val-resize-size 256 --val-crop-size 256 --train-crop-size 256
Epoch: [221] Total time: 0:25:25
Test:   [ 0/98]  eta: 0:06:23  loss: 1.3831 (1.3831)  acc1: 91.4062 (91.4062)  acc5: 98.4375 (98.4375)  time: 3.9118 data: 3.6845  max mem: 14880
Test:  Total time: 0:00:25
Test:  Acc@1 77.738 Acc@5 94.460
Test: EMA  [ 0/98]  eta: 0:05:15  loss: 1.3172 (1.3172)  acc1: 90.6250 (90.6250)  acc5: 99.2188 (99.2188)  time: 3.2230 data: 2.9882  max mem: 14880
Test: EMA Total time: 0:00:25
Test: EMA Acc@1 80.012 Acc@5 95.286

Thank you for sharing the commands. Although you previously mention I missed that in v2 they used a different resolution size. I launched new jobs to train all the variants with the correct resolution size and seems are looking better now.

@jdsgomes
Copy link
Contributor

jdsgomes commented Aug 9, 2022

@ain-soph thank you for your patience and great work! I managed to train all three variants and the results look good.

My plan is to update this PR in the next couple of ours with the model weights.

@jdsgomes
Copy link
Contributor

jdsgomes commented Aug 9, 2022

Trainning commands:

# swin_v2_t
python -u run_with_submitit.py --timeout 3000 --ngpus 8 --nodes 1  --model swin_v2_t --epochs 300 \
--batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0  \
--bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr \
--lr-min 0.00001 --lr-warmup-method linear  --lr-warmup-epochs 20 --lr-warmup-decay 0.01 \
--amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 \
--random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler \
--ra-reps 4  --val-resize-size 256 --val-crop-size 256 --train-crop-size 256 

# swin_v2_s
python -u run_with_submitit.py --timeout 3000 --ngpus 8 --nodes 1  --model swin_v2_s --epochs 300 \
--batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0  \
--bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr \
--lr-min 0.00001 --lr-warmup-method linear  --lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 \
--random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler \
--ra-reps 4  --val-resize-size 256 --val-crop-size 256 --train-crop-size 256

# swin_v2_b
python -u run_with_submitit.py --timeout 3000 --ngpus 8 --nodes 1  --model swin_v2_b --epochs 300 \
--batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0  \
--bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr \
--lr-min 0.00001 --lr-warmup-method linear  --lr-warmup-epochs 20 --lr-warmup-decay 0.01 \
--amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 \
--random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler \
--ra-reps 4  --val-resize-size 256 --val-crop-size 256 --train-crop-size 256 

Test commands and acuracies

# swin_v2_t
srun -p dev --cpus-per-task=96 -t 24:00:00 --gpus-per-node=1 torchrun --nproc_per_node=1 train.py     \
    --model swin_v2_t --test-only --resume $EXPERIMENTS_PATH/44757/model_299.pth --interpolation bicubic  \
    --val-resize-size 260 --val-crop-size 256
# Test:  Acc@1 82.072 Acc@5 96.132

# swin_v2_s
srun -p dev --cpus-per-task=96 -t 24:00:00 --gpus-per-node=1 torchrun --nproc_per_node=1 train.py     \
    --model swin_v2_s --test-only --resume $EXPERIMENTS_PATH/44758/model_299.pth --interpolation bicubic   \
    --val-resize-size $260--val-crop-size 256
# Test:  Acc@1 83.712 Acc@5 96.816

# swin_v2_b
srun -p dev --cpus-per-task=96 -t 24:00:00 --gpus-per-node=1 torchrun --nproc_per_node=1 train.py     \
    --model swin_v2_b --test-only --resume $EXPERIMENTS_PATH/44759/model_299.pth --interpolation bicubic   \
    --val-resize-size 272 --val-crop-size 256
# Test:  Acc@1 84.112 Acc@5 96.864

@ain-soph
Copy link
Contributor Author

ain-soph commented Aug 9, 2022

I guess the only final thing on our plate is to fix the memory issue. I’ll work on it in the following week.

Btw, should we provide porting model weights from official repo as an alternative?

@datumbox
Copy link
Contributor

datumbox commented Aug 9, 2022

@ain-soph @jdsgomes Awesome work! Having SwinV2 in TorchVision is really awesome. Looking forward using them.

Btw, should we provide porting model weights from official repo as an alternative?

I think given we were able to reproduce the accuracy of the paper, there is no point offering both. What would be interesting on the future is to offer higher accuracy weights by using newer training recipes.

Copy link
Contributor

@jdsgomes jdsgomes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on Yosua previous review of the ML side and mine replication of the results approving. Thanks @ain-soph for you patience and great work.

As soon as the tests are green we we are good to merge

@jdsgomes jdsgomes changed the title [WIP] Add SwinV2 Add SwinV2 Aug 10, 2022
Copy link
Contributor

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reviewed @jdsgomes 5 last commits and everything LGTM. I would be OK to merge on green CI (if the memory issue persists on the CI, we might need to turn off some of the tests).

@jdsgomes jdsgomes merged commit 5521e9d into pytorch:main Aug 10, 2022
@ain-soph ain-soph deleted the swin_transfomer_v2 branch August 10, 2022 21:39
@ain-soph ain-soph restored the swin_transfomer_v2 branch August 10, 2022 21:49
@ain-soph ain-soph deleted the swin_transfomer_v2 branch August 10, 2022 21:50
facebook-github-bot pushed a commit that referenced this pull request Aug 23, 2022
Summary:
* init submit

* fix typo

* support ufmt and mypy

* fix 2 unittest errors

* fix ufmt issue

* Apply suggestions from code review

* unify codes

* fix meshgrid indexing

* fix a bug

* fix type check

* add type_annotation

* add slow model

* fix device issue

* fix ufmt issue

* add expect pickle file

* fix jit script issue

* fix type check

* keep consistent argument order

* add support for pretrained_window_size

* avoid code duplication

* a better code reuse

* update window_size argument

* make permute and flatten operations modular

* add PatchMergingV2

* modify expect.pkl

* use None as default argument value

* fix type check

* fix indent

* fix window_size (temporarily)

* remove "v2_" related prefix and add v2 builder

* remove v2 builder

* keep default value consistent with official repo

* deprecate dropout

* deprecate pretrained_window_size

* fix dynamic padding edge case

* remove unused imports

* remove doc modification

* Revert "deprecate dropout"

This reverts commit 8a13f93.

* Revert "fix dynamic padding edge case"

This reverts commit 1c7579c.

* remove unused kwargs

* add downsample docs

* revert block default value

* revert argument order change

* explicitly specify start_dim

* add small and base variants

* add expect files and slow_models

* Add model weights and documentation for swin v2

* fix lint

* fix end of files line

Reviewed By: datumbox

Differential Revision: D38824237

fbshipit-source-id: d94082c210c26665e70bdf8967ef72cbe3ed4a8a

Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Joao Gomes <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add SwinV2 in TorchVision
6 participants