New Model Architectures - Implementation and Documentation Details #5319
Comments
I've pinned the issue for now, but we should consider using the Wiki pages, which are better suited to this kind of content. PyTorch core uses them extensively, so it might be worth aligning.
Are we expecting these guidelines to change often?
That is a good point, and it has been discussed previously with @datumbox and also during this PR review. In short, I think it is fair to say that there are no strong feelings either way, but there were two main arguments for keeping it in a ticket for now. First, we didn't want to make the contribution guidelines too long; second, the content can change. So I would still favour keeping it here for a while, and if it seems stable enough we can move it to a …
Things will eventually become stable. I think the biggest changes to this documentation will come from the following:
Once things stabilise, we can move to …
Thanks for this great guide @jdsgomes. There is a small typo here: ConvNormActication
🚀 The feature
When adding a new model architecture, there are some design/implementation details and documentation requirements that need to be taken into account. This issue intends to track these details dynamically, as they may change over time.
Motivation, pitch
New Model Architectures - Implementation Details
Model development and training steps
When developing a new model, make sure not to miss the following details:
Implement a model factory function for each of the model variants.
In the module constructor, pass layer constructors instead of instances for configurable layers such as norm and activation, and log the API usage with _log_api_usage_once(self).
Where possible, fuse layers using existing common blocks; for example, consecutive conv, BN and activation layers can be replaced by ConvNormActivation.
Define __all__ at the beginning of the model file to expose the model factory functions, and import the model's public APIs (e.g. factory methods) in torchvision/models/__init__.py.
Create the model builder using the new API and add it to the prototype area. Here is an example of how to do this. The new API requires adding more information about the weights, such as the preprocessing transforms needed to use the model, metadata about the model, etc.
Make sure you write tests for the model itself (see _check_input_backprop and _model_params in test/test_models.py) and for any new operators/transforms or important functions that you introduce.
The new model should be scriptable with TorchScript (using torch.jit.script).
The new model should be FX-compatible (using torch.fx.symbolic_trace).
Note that this list is not exhaustive: there are further details related to code quality etc., but those are rules that apply to all PRs (see Contributing to TorchVision).
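You can check the TorchScript, FX and input-backprop requirements locally before running the full test suite with a small helper like the one below. It is a hypothetical sketch, not torchvision's test code. check_model_requirements and ToyNet are illustrative names, and the backprop part only mimics the idea behind _check_input_backprop: gradients must reach the input.

```python
import torch
import torch.fx
from torch import nn


def check_model_requirements(model: nn.Module, x: torch.Tensor) -> None:
    """Hypothetical helper: eager, TorchScript and FX outputs must agree,
    and gradients must flow back to the input."""
    model.eval()
    with torch.no_grad():
        expected = model(x)
        scripted = torch.jit.script(model)  # raises if the model is not scriptable
        traced = torch.fx.symbolic_trace(model)  # raises on data-dependent control flow
        torch.testing.assert_close(scripted(x), expected)
        torch.testing.assert_close(traced(x), expected)

    # Backprop check: the input should receive a gradient.
    inp = x.clone().requires_grad_(True)
    model(inp).sum().backward()
    if inp.grad is None:
        raise AssertionError("no gradient reached the input")


class ToyNet(nn.Module):
    """Illustrative model: conv -> norm -> activation -> pool -> linear."""

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(self.features(x))
        return self.fc(torch.flatten(x, 1))


if __name__ == "__main__":
    check_model_requirements(ToyNet(), torch.randn(2, 3, 16, 16))
    print("ok")
```

A forward pass that branches on tensor values (for example `if x.sum() > 0:`) would fail the torch.fx.symbolic_trace step. That is the kind of problem this check is meant to surface early.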
Once the model is implemented, you need to train it using the reference scripts. For example, to train a resnet18 classification model you would:
go to references/classification
run the train command (for example, torchrun --nproc_per_node=8 train.py --model resnet18)
After training the model, select the best checkpoint and estimate its accuracy with a batch size of 1 on a single GPU. This gives us more reliable measurements of the models' accuracy and avoids variations introduced by batch padding (read here for more details).
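Distributed evaluation is one common source of such padding. If the evaluation loader uses torch.utils.data.DistributedSampler with its default drop_last=False, the sampler repeats samples so that every process gets the same number of them. The stdlib-only sketch below mirrors that index assignment; it is not the PyTorch implementation. It shows how many images end up being evaluated twice for a given world size.

```python
import math
from typing import List


def distributed_eval_indices(num_samples: int, world_size: int) -> List[List[int]]:
    """Mirror DistributedSampler(shuffle=False, drop_last=False) index assignment.

    Returns the dataset indices each rank would evaluate.
    """
    per_rank = math.ceil(num_samples / world_size)
    total_size = per_rank * world_size
    indices = list(range(num_samples))
    padding = total_size - num_samples
    # Pad by repeating indices from the start so every rank gets per_rank samples.
    if padding <= len(indices):
        indices += indices[:padding]
    else:
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Rank r takes every world_size-th index, starting at r.
    return [indices[rank:total_size:world_size] for rank in range(world_size)]


def duplicated_samples(num_samples: int, world_size: int) -> int:
    """Number of evaluations that repeat an already-evaluated sample."""
    per_rank = math.ceil(num_samples / world_size)
    return per_rank * world_size - num_samples


if __name__ == "__main__":
    # A 50,000-image validation set (the size of the ImageNet-1k validation split).
    for world_size in (1, 8, 64):
        print(world_size, duplicated_samples(50_000, world_size))
    # prints: 1 0 / 8 0 / 64 48
```

With 64 processes, 48 images would be counted twice and slightly bias the reported accuracy. With a single process, nothing needs to be padded.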
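Outside the reference scripts, single-process, batch-size-1 evaluation comes down to a loop like the one below. It is a hedged sketch. evaluate_top1 is a hypothetical name, it computes plain top-1 accuracy only, and the dataset is assumed to yield (image_tensor, class_index) pairs.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


@torch.inference_mode()
def evaluate_top1(model: nn.Module, dataset: Dataset, device: str = "cpu") -> float:
    """Top-1 accuracy with batch size 1 in a single process (no sampler padding)."""
    model.eval().to(device)
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    correct = 0
    total = 0
    for image, target in loader:
        output = model(image.to(device))
        correct += int((output.argmax(dim=1) == target.to(device)).sum())
        total += target.numel()
    # Every sample is visited exactly once, so total == len(dataset).
    return correct / total
```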
Finally, run the model tests to generate the expected model files used for testing, and include those generated files in the PR as well:
EXPECTTEST_ACCEPT=1 pytest test/test_models.py -k {model_name}
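Setting EXPECTTEST_ACCEPT=1 makes the tests write the expected outputs instead of comparing against them, which is why the generated files must be committed. The stdlib-only sketch below illustrates that accept-or-compare pattern. assert_expected and its JSON file layout are hypothetical and do not reproduce torchvision's own test helpers or file format.

```python
import json
import os
from pathlib import Path
from typing import List


def assert_expected(output: List[float], name: str, expect_dir: Path, atol: float = 1e-5) -> None:
    """Hypothetical accept-or-compare helper in the style of expecttest.

    With EXPECTTEST_ACCEPT=1 the expected file is (re)written from `output`;
    otherwise `output` is compared against the stored file.
    """
    expect_file = expect_dir / f"{name}_expect.json"
    if os.environ.get("EXPECTTEST_ACCEPT", "0") == "1":
        expect_dir.mkdir(parents=True, exist_ok=True)
        expect_file.write_text(json.dumps(output))
        return
    if not expect_file.exists():
        raise FileNotFoundError(f"{expect_file} not found; rerun with EXPECTTEST_ACCEPT=1")
    expected = json.loads(expect_file.read_text())
    if len(expected) != len(output):
        raise AssertionError(f"expected {len(expected)} values, got {len(output)}")
    for i, (got, exp) in enumerate(zip(output, expected)):
        if abs(got - exp) > atol:
            raise AssertionError(f"value {i}: got {got}, expected {exp} (atol={atol})")
```

Running once with the variable set creates the file. Later runs without it fail if the model's output drifts beyond atol.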
Documentation and PyTorch Hub
docs/source/models.rst:
add the model to the corresponding section (classification/detection/video etc.)
describe how to construct the model variants (with and without pre-trained weights)
add model metrics and a reference to the original paper
hubconf.py:
import the model factory functions
submit a PR to https://github.com/pytorch/hub with a model page (or update an existing one)
README.md under the reference script folder:
Alternatives
No response
Additional context
No response