-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding ViT to torchvision/models #4594
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR @sallysyw, I didn't check the architecture for now. The code looks great, I just took a brief look at the docstrings and public/private interface
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR @sallysyw, great addition.
I've added a few comments related to the conventions used at TorchVision. Let me know your thoughts. I'm happy to review the ML bit if you want, I just need to freshen up the paper.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a bunch of thoughts / suggestions. Feel free to ignore whatever doesn't make sense or to keep things simple!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As far as I can tell, the only comments that might affect the API of the class are Francisco's remark on the classifier and Mannat's on embedding interpolation. The latter can be addressed on a follow up PR but it's worth considering the implications on the API. The good. thing is, we are just after a release, so even if we merge and want to review the API, we got plainly of time to do so. Other than that, I think the PR is mergeable. I left a couple of minor comments but none of them are blocking and mostly FYIs and nits.
@sallysyw I haven't checked the validity of the ViT implementation itself comparing to the paper. I know you are porting this from trusted sources, but if you want me to check it let me know.
This is a great point! If you train from scratch and the axes are messed up, you still get reasonable results sometimes (speaking from experience lol). We should try and maybe repro a result from here - https://github.com/facebookresearch/ClassyVision/tree/main/examples/vit Or if it's easier, we can use a pretrained model from any source (like Classy), and evaluate it with this implementation and verify that the accuracy matches! |
Thanks @mannatsingh, I've been training vit_b_32 from scratch on AWS cluster, once the training is finished and the results look fine, I'll update it here. |
I've finished the first iteration of training the
I plan to further tune the parameters and I'll upload the pre-trained weights in a following PR once I got the accuracy numbers matching the previous results. Let me know if there's other concerns before I can merge this PR. |
Do you know where can I find some pre-trained classy-vision checkpoints? |
@sallysyw Thanks for the update. As far as I can see the unit-tests are failing and it seems related. From what I can see the FX feature extractor seg faults on vit_b_16. This needs to be fixed before we merge:
This is a good idea to do. I don't know where you could find one (worth checking with internal teams to see if they have checkpoints you can use) but it's definitely a necessary step prior merging. I would also recommend, unless you have other important reasons, to fully reproduce the achieved accuracies before merging. In the past, there were cases were the cause of the drop in accuracy was identified in the architecture itself. Though you can mitigate that risk by loading pre-trained checkpoints and reproducing their accuracies to the digit, this doesn't cover for bugs for when the model is on training mode.
Some of the above comments, for example Francisco's #4594 (comment) and #4594 (comment), affect the architecture so if you merge now you might have to make radical changes later. Given that the next release is months from now, we got time to fix them. Still given that there are outstanding comments, it would be good to confirm with @mannatsingh and @fmassa that an early merge is OK in this case. Another option is to merge this on |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All look great to me. I left one last comment for something missed in earlier reviews and 2 optional nits. Hopefully this is the last round before merging,
whoops - trying to import
But I confirmed locally that after the changes the tests passed on linux cpu. cc @datumbox |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@sallysyw I tested them locally and the tests fail:
[W ir_emitter.cpp:4340] Warning: Dict values consist of heterogeneous types, which means that the dict has been typed as containing Dict[str, Union[Tensor, Tuple[Tensor, Optional[Tensor]]]]. To use any of the values in this Dict, it will be necessary to add an `assert isinstance` statement before first use to trigger type refinement.
File "<eval_with_key>.76", line 196
eq_14 = dim_12 == 3; dim_12 = None
_assert_14 = torch._assert(eq_14, 'Expected (seq_length, batch_size, hidden_dim) got Proxy(getattr_14)'); eq_14 = None
return {'encoder.dropout': encoder_dropout, 'encoder.layers.encoder_layer_5.ln': encoder_layers_encoder_layer_5_ln_1, 'encoder.layers.encoder_layer_6.ln': encoder_layers_encoder_layer_6_ln_1, 'encoder.layers.encoder_layer_7.add': add_15, 'encoder.layers.encoder_layer_7.add_1': add_16, 'encoder.layers.encoder_layer_8.ln_1': encoder_layers_encoder_layer_8_ln_2, 'encoder.layers.encoder_layer_8.mlp.linear_1': encoder_layers_encoder_layer_8_mlp_linear_2, 'encoder.layers.encoder_layer_10.self_attention': encoder_layers_encoder_layer_10_self_attention, 'encoder.layers.encoder_layer_10.add': add_21, 'encoder.layers.encoder_layer_10.mlp.dropout_1': encoder_layers_encoder_layer_10_mlp_dropout_2}
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
(function operator())
test/test_backbone_utils.py:187 (TestFxFeatureExtraction.test_jit_forward_backward[vit_b_16])
Traceback (most recent call last):
File "./vision/test/test_backbone_utils.py", line 198, in test_jit_forward_backward
sum(o.mean() for o in fgn_out.values()).backward()
File "./vision/test/test_backbone_utils.py", line 198, in <genexpr>
sum(o.mean() for o in fgn_out.values()).backward()
AttributeError: 'tuple' object has no attribute 'mean'
I think we moved the code to prototype to early (that's on me) before we confirm that all issues are fixed. Since the tests on prototypes are not executed on CI, we now risk not detecting issues with the implementation.
I propose undoing the move to prototype, running the tests on the CI and ensure everything works prior moving everything to the prototype again. I did this job on a separate no-merge PR at #4984 and as you can see the tests fail. You are welcome to merge my changes into your current PR to investigate.
hmm... I think the seed generated on your machine is different from mine and that's why I didn't catch this failure previously. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@sallysyw All worked like a charm. I pushed the removal of the extra classed from your branch and merged main. We should be good to merge whenever you want. Thanks for the great contribution, looking forward to the weights.
Hey @sallysyw! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py |
Summary: * [vit] Adding ViT to torchvision/models * adding pre-logits layer + resolving comments * Fix the model attribute bug * Change version to arch * fix failing unittests * remove useless prints * reduce input size to fix unittests * Increase windows-cpu executor to 2xlarge * Use `batch_first=True` and remove classifier * Change resource_class back to xlarge * Remove vit_h_14 * Remove vit_h_14 from __all__ * Move vision_transformer.py into prototype * Fix formatting issue * remove arch in builder * Fix type err in model builder * address comments and trigger unittests * remove the prototype import in torchvision.models * Adding vit back to models to trigger CircleCI test * fix test_jit_forward_backward * Move all to prototype. * Adopt new helper methods and fix prototype tests. * Remove unused import. Reviewed By: NicolasHug Differential Revision: D32694316 fbshipit-source-id: fa2867555fb7ae65f8dab537517386f6694585a2 Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]>
The first part of #4593 :)
cc @datumbox @bjuncek