PyTorch ONNX exporter
Page Maintainers: @abock, @justinchuby
Documentation for developing the PyTorch-ONNX exporter (torch.onnx). For an index of all ONNX exporter related topics, see PyTorch ONNX Topics.
We highly recommend using Linux. Other platforms are not tested in PyTorch CI and are generally not used by the torch.onnx developers.
Fork github.com/pytorch/pytorch and clone your fork to your workstation.
Run
git submodule update --init --recursive --jobs 0
CUDA is not required for most development tasks. If you use CUDA, building PyTorch will probably be slower.
Install Anaconda and activate a new environment.
Install direnv and initialize your envrc file in the root of your PyTorch git repo:
NOTE: Remember to hook direnv into your shell after you install it.
# Make the local package name built by `setup.py develop` the same
# as the one that's on conda.
echo "export TORCH_PACKAGE_NAME=pytorch" >> .envrc
# Let CMake find binaries and libs installed by conda.
echo 'export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}' >> .envrc
direnv allow
On Azure Linux:
sudo dnf install glibc-devel kernel-headers
Then see the instructions in PyTorch's README.
Use direnv for Anaconda environment selection.
Set more environment variables in your .envrc file:
# Only if you're building without CUDA.
export USE_CUDA=0
# Only if you're building with ccache.
PATH_add /usr/lib/ccache
# Needed for older compilers or conda compilers
export LDFLAGS='-lrt'
# Build with debug symbols.
export DEBUG=1
Install the dependencies required for development and to run CI checks locally.
pip install expecttest pytest parameterized flake8 hypothesis pytest-cov pytest-xdist pytest-subtests pylint lintrunner ghstack beartype
lintrunner init
Read more about:
- lintrunner (required): Runs all the linters and ensures consistency between the CI and local development environments.
- ghstack (optional): Conveniently submits stacks of diffs to GitHub as separate pull requests. NOTE: GitLens's interactive rebase feature comes in handy with ghstack.
- To recover your branch from ghstack:
ghstack checkout github_link_to_pr
pip install onnxruntime onnx
The ONNX tests depend on torchvision. This is tricky because torchvision depends on PyTorch, but we don't want the package manager to install PyTorch; we want to use our locally built one.
# If you're not using CUDA, use the command below. If you are, see https://pytorch.org/get-started/locally/
pip install --upgrade --no-deps --pre torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu
# manually install torchvision deps
You should be able to run these commands successfully:
python setup.py develop
pytest -svk test_arithmetic_prim_long test/onnx/test_pytorch_onnx_onnxruntime.py
And this should fail:
echo "assert False" >> torch/onnx/utils.py
pytest -svk test_arithmetic_prim_long test/onnx/test_pytorch_onnx_onnxruntime.py
git restore torch/onnx/utils.py
If the second command succeeds, Python is probably finding a PyTorch installed via conda or pip, not the one built from source by python setup.py develop.
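To check which PyTorch your interpreter resolves, you can locate the torch package without importing it. This is a minimal sketch that assumes you run it from the root of your PyTorch checkout; the helper names is_from_checkout and report_torch_location are hypothetical.

```python
import importlib.util
from pathlib import Path


def is_from_checkout(module_file: str, repo_root: str) -> bool:
    """Return True if module_file lives inside repo_root."""
    module_path = Path(module_file).resolve()
    root = Path(repo_root).resolve()
    # Comparing against parents works on Python versions without Path.is_relative_to.
    return root == module_path or root in module_path.parents


def report_torch_location(repo_root: str = ".") -> None:
    # find_spec locates the top-level package without executing torch/__init__.py.
    spec = importlib.util.find_spec("torch")
    if spec is None or spec.origin is None:
        print("torch is not importable from this interpreter")
        return
    if is_from_checkout(spec.origin, repo_root):
        where = "source checkout"
    else:
        where = "an installed package (conda/pip)"
    print(f"torch resolves to {spec.origin} ({where})")


if __name__ == "__main__":
    report_torch_location()
```

After python setup.py develop, the reported path should point into your repository's torch/ directory.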
You can place this recommended settings.json under .vscode/:
{
// Always remove trailing whitespaces
"files.trimTrailingWhitespace": true,
"files.insertFinalNewline": true,
"files.trimFinalNewlines": true,
"[python]": {
"editor.tabSize": 4,
// Set to true to auto sort imports
"editor.codeActionsOnSave": {
"source.organizeImports": false
},
"editor.rulers": [88],
},
// Enable Python linting and Pylance type checking
"python.analysis.typeCheckingMode": "basic",
"python.formatting.provider": "black",
"python.sortImports.args": ["--profile", "black"],
"python.linting.enabled": true,
"python.linting.flake8Enabled": true,
"python.linting.pydocstyleEnabled": true,
"python.linting.pydocstyleArgs": ["--convention=google"],
"python.linting.banditEnabled": true,
"python.linting.pylintEnabled": true,
"python.linting.pylintArgs": [
"--disable=no-member"
]
}
Recommended extensions (you can list them in .vscode/extensions.json or install them from the "Extensions" tab):
{
"recommendations": [
// Python
"ms-python.python",
"ms-python.vscode-pylance",
"njpwerner.autodocstring",
// Markdown
"yzhang.markdown-all-in-one",
"DavidAnson.vscode-markdownlint",
// Coding style
"shardulm94.trailing-spaces",
// Linting display
"usernamehw.errorlens",
"igorsbitnev.error-gutters",
"ryanluker.vscode-coverage-gutters",
// GitHub review integration
"GitHub.vscode-pull-request-github",
"eamodio.gitlens",
// Show changed files between branches
"letmaik.git-tree-compare",
]
}
If you use Error Lens, the following settings are recommended:
{
"errorLens.excludeBySource": [
"cSpell" // Exclude noisy spelling errors
],
"errorLens.followCursor": "closestProblem",
"errorLens.fontSize": "0.9em", // Smaller unintrusive messages
"errorLens.followCursorMore": 3, // Hide errors too far away from the cursor
}
You can set up VS Code to run gdb and set breakpoints when debugging C++ code. In launch.json, add the following configuration:
// ...
"configurations": [
{
"name": "(gdb) Launch",
"type": "cppdbg",
"request": "launch",
"program": "<path to python bin>",
"args": [
"-m",
"pytest",
"<test file and test name in pytest format>"
],
"stopAtEntry": false,
"cwd": "path/to/repo/of/pytorch",
"environment": [],
"externalConsole": false,
"MIMode": "gdb",
"setupCommands": [
{
"description": "Enable pretty-printing for gdb",
"text": "-enable-pretty-printing",
"ignoreFailures": true
},
{
"description": "Set Disassembly Flavor to Intel",
"text": "-gdb-set disassembly-flavor intel",
"ignoreFailures": true
}
]
},
]
You can then set breakpoints in the C++ source and run the debugger in VS Code.
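If you are unsure what to put in "program" and "cwd", you can generate the entry from the interpreter that has your source build installed. This is a sketch; the helper name gdb_launch_config is hypothetical, and it reproduces only the pretty-printing setup command from the configuration above.

```python
import json
import sys
from pathlib import Path
from typing import Dict, List


def gdb_launch_config(pytest_args: List[str], repo_root: str) -> Dict[str, object]:
    """Build a cppdbg launch entry that runs pytest under gdb."""
    return {
        "name": "(gdb) Launch",
        "type": "cppdbg",
        "request": "launch",
        # The interpreter running this script; run it from your dev environment.
        "program": sys.executable,
        "args": ["-m", "pytest", *pytest_args],
        "stopAtEntry": False,
        "cwd": str(Path(repo_root).resolve()),
        "environment": [],
        "externalConsole": False,
        "MIMode": "gdb",
        "setupCommands": [
            {
                "description": "Enable pretty-printing for gdb",
                "text": "-enable-pretty-printing",
                "ignoreFailures": True,
            }
        ],
    }


if __name__ == "__main__":
    config = gdb_launch_config(
        ["test/onnx/test_pytorch_onnx_onnxruntime.py", "-k", "test_arithmetic_prim_long"], "."
    )
    print(json.dumps({"configurations": [config]}, indent=4))
```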
PRs should be opened directly against main. A PR can be merged directly into main as long as it satisfies the ONNX merge rule:
- It is approved by one of the torch.onnx developers listed in the approved_by section.
- All modified files fall under the patterns section.
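The second condition is essentially a glob match of every changed path against the rule's patterns. The sketch below uses fnmatch; the patterns are illustrative placeholders, not the actual contents of the ONNX merge rule.

```python
from fnmatch import fnmatch
from typing import Iterable, List

# Hypothetical patterns for illustration only; the real ones live in the merge rule.
EXAMPLE_PATTERNS = ["torch/onnx/*", "test/onnx/*", "torch/csrc/jit/passes/onnx/*"]


def unmatched_files(changed: Iterable[str], patterns: List[str]) -> List[str]:
    """Return the changed paths that match none of the patterns."""
    # fnmatch's "*" also matches "/", so "torch/onnx/*" covers nested files too.
    return [path for path in changed if not any(fnmatch(path, p) for p in patterns)]


def within_merge_rule(changed: Iterable[str], patterns: List[str]) -> bool:
    """True if every changed path is covered by at least one pattern."""
    return not unmatched_files(changed, patterns)
```

Any path returned by unmatched_files is a file outside the rule's scope.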
Pay special attention to the following GitHub checks:
- Has "onnx" in the name, which runs ONNX related tests.
- Has "Lint" in the name, which does code format checks.
Regarding other failing checks: if you are certain the failure is unrelated to your change, try rebasing on main. These failures are often caused by a branch being out of sync with main. You can ignore a failing check if it is a regression in main; verify this by checking on the CI HUD whether main is also failing.
To merge your pull request, comment @pytorchbot merge on the PR (see Bot commands).
If you change non-ONNX code, i.e. files outside the ONNX merge rule, the PR will require additional reviews from people outside the torch.onnx developers and will take longer to merge into main. In this case, pytorchbot will not be able to merge the pull request and will leave a comment like "Merge failed due to PR XXX does not match merge rules". Please label such pull requests with onnx-needs-import.
See GitHub pull request workflow.
Adhere to Google's Code Review Developer Guide and the PyTorch code review values.
Running all the tests locally takes a very long time, so generally you should run a few tests locally and rely on GitHub CI checks for comprehensive testing.
We highly recommend using pytest to run tests selectively.
Note that you should use python -m pytest rather than calling pytest directly, to make sure it uses your locally built version of PyTorch.
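The same idea applies when launching tests from a script: run pytest as a module of the current interpreter so it imports the torch that interpreter resolves. A minimal sketch; pytest_command is a hypothetical helper.

```python
import sys
from typing import List, Optional


def pytest_command(test_path: str, *extra: str, keyword: Optional[str] = None) -> List[str]:
    """Build a `python -m pytest` command bound to the current interpreter."""
    # sys.executable is the interpreter running this script, so pytest imports
    # the same torch (e.g. the one from `python setup.py develop`).
    cmd = [sys.executable, "-m", "pytest", test_path]
    if keyword:
        cmd += ["-k", keyword]
    cmd += list(extra)
    return cmd


if __name__ == "__main__":
    cmd = pytest_command(
        "test/onnx/test_pytorch_onnx_onnxruntime.py", "-sv", keyword="test_arithmetic_prim_long"
    )
    print(" ".join(cmd))
```

The returned list can be passed directly to subprocess.run.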
Most relevant tests are in test/onnx/.
The most used test file is test_pytorch_onnx_onnxruntime.py. The tests in this file generally:
- Define a subclass of torch.nn.Module.
- Define some inputs.
- Call self.run_test() with the instantiated module and inputs.
run_test() converts the module to ONNX and compares the output between PyTorch and ONNX Runtime.
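The comparison step is an element-wise closeness check between the two runtimes' outputs. The plain-Python sketch below only illustrates that idea; the helper name assert_outputs_close and its tolerance defaults are assumptions, not what run_test() actually uses.

```python
from typing import Sequence


def assert_outputs_close(
    expected: Sequence[float],
    actual: Sequence[float],
    rtol: float = 1e-3,
    atol: float = 1e-7,
) -> None:
    """Raise AssertionError unless |actual - expected| <= atol + rtol * |expected| element-wise."""
    if len(expected) != len(actual):
        raise AssertionError(f"length mismatch: {len(expected)} != {len(actual)}")
    for i, (e, a) in enumerate(zip(expected, actual)):
        # Same criterion as numpy.testing.assert_allclose; NaNs always fail here.
        if not abs(a - e) <= atol + rtol * abs(e):
            raise AssertionError(f"element {i}: expected {e}, got {a}")
```

In this framing, expected would be the PyTorch output and actual the ONNX Runtime output for the same inputs.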
Tests added to TestONNXRuntime are automatically defined for all supported opset versions. Use the -k option in pytest to run the test you want.
For example:
# run the `test_quantized_arithmetic_qfunctional` test
python -m pytest test/onnx/test_pytorch_onnx_onnxruntime.py -k test_quantized_arithmetic_qfunctional
An example of adding unit tests for a new symbolic function: Add binary_cross_entropy_with_logits op
You can use pytest to run tests in parallel and generate a coverage report (this relies on the pytest-xdist and pytest-cov plugins installed above).
python -m pytest -n auto --cov --cov-report "xml:test/coverage.xml" test/onnx/test_pytorch_onnx_onnxruntime.py
Set the environment variable TORCH_LOGS="onnx_diagnostics" to capture detailed diagnostics.
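When launching tests from Python, you can set the variable on a copied environment so your own shell stays unchanged. A minimal sketch; with_onnx_diagnostics is a hypothetical helper, and it assumes TORCH_LOGS takes a comma-separated list of components, so an existing value is extended rather than overwritten.

```python
import os
from typing import Dict, Mapping, Optional


def with_onnx_diagnostics(base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Return a copy of base (default: os.environ) with onnx_diagnostics in TORCH_LOGS."""
    env = dict(os.environ if base is None else base)
    # Keep any components that are already enabled, dropping empty entries.
    components = [c for c in env.get("TORCH_LOGS", "").split(",") if c]
    if "onnx_diagnostics" not in components:
        components.append("onnx_diagnostics")
    env["TORCH_LOGS"] = ",".join(components)
    return env
```

The result can be passed as the env argument of subprocess.run when invoking python -m pytest.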
- User-facing doc: docs/source/onnx.rst
- Python tests: test/onnx/
- More Python tests: test/jit/test_onnx_export.py
- Python code: torch/onnx/
- C++ code: torch/csrc/jit/passes/onnx/
https://github.com/pytorch/pytorch/issues/116684
Pre-dispatch may skip functionalization
To support quantized model export, we need to unpack the quantized tensor inputs and the PackedParam weights (https://github.com/pytorch/pytorch/pull/69232). We construct the replacement through TupleConstruct to get a 1-to-1 input mapping, so that we can use the replaceAllUsesWith API for its successors. In addition, we support exporting the quantized namespace, and developers can conveniently add more symbolics for quantized operators in the current framework.
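The 1-to-1 mapping idea can be illustrated on a toy graph. The packed value is replaced by a single TupleConstruct output built from its unpacked parts, so every consumer can be redirected with one replace-all-uses call. The classes and helpers below are a hypothetical plain-Python model, not the TorchScript graph API used in torch/csrc/jit/passes/onnx/.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass(eq=False)  # identity-based equality, like graph values
class Value:
    name: str


@dataclass(eq=False)
class Node:
    kind: str
    inputs: List[Value]
    outputs: List[Value] = field(default_factory=list)


@dataclass
class Graph:
    inputs: List[Value] = field(default_factory=list)
    nodes: List[Node] = field(default_factory=list)

    def replace_all_uses_with(self, old: Value, new: Value) -> None:
        # Redirect every consumer of `old` to read `new` instead.
        for node in self.nodes:
            node.inputs = [new if v is old else v for v in node.inputs]


def unpack_quantized_input(graph: Graph, packed: Value) -> Node:
    """Replace a packed quantized input with (tensor, scale, zero_point) inputs."""
    parts = [Value(f"{packed.name}.{s}") for s in ("tensor", "scale", "zero_point")]
    idx = graph.inputs.index(packed)
    graph.inputs[idx:idx + 1] = parts
    # A single tuple output stands in for exactly one old value (1-to-1).
    tuple_node = Node("prim::TupleConstruct", list(parts), [Value(f"{packed.name}.tuple")])
    graph.nodes.insert(0, tuple_node)
    graph.replace_all_uses_with(packed, tuple_node.outputs[0])
    return tuple_node
```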
The ONNX packages installed in the CI Docker images can be updated in https://github.com/pytorch/pytorch/blob/main/.ci/docker/common/install_onnx.sh