Commit e5997d5
committed
Update on "Add base forward grad logic"
RFC: pytorch/rfcs#11
This PR add the basic logic to handle forward grad as dual Tensors.
It contains the following:
- Mechanism to save dual state on a Tensor and clear it up when the dual level ends
- C++ and python user facing API
- Updated view system that is able to track both forward and backward views
The current PR has the following limitations:
- Extensive tests are in the next PR in the stack as formulas are needed to write full tests.
- Only the manual formulas have been audited and no other formula is actually implemented here (they are in the next PR in the stack)
- Only level 0 is allowed for now. This was discussed and agreed that it is not needed for the first version of this PR.
- We can save one ViewInfo creation when both the forward and backward views have the same base. This can be done by adding a boolean flag to the DifferentiableViewMeta and extra logic in the `as_view` method. This is left out to keep this PR concise.
- We can skip tracking forward views if the base has a forward grad. This can be done by adding extra logic in the `as_view` method. This is left out to keep this PR concise.
Reading guide:
- Updated view handling in [gen_variable_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-f6553cec68caeaea36f6c8b14ff76a6d39dfd774e0ea9ef2f76e8d81fd9af5df), [VariableTypeUtils.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-ec71cfa45954dece1236c661d170e6341879c5be637f4abf52e826d61b40695a), [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285) (skip code below "[Forward Grad View]" for now), [variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-1604bcd0e4350ed99ec45e437cee7ac9ebe337392c9ea16a236247aeeb35b02bR266-R542) and [custom_function.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-dd85f452082b5bb6612bbc12adb496f8827defa228509f7b493de1d517522d5d).
- New forward grad class that handle storing gradients and tracking at each level [forward_grad.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c6c5b9ab2d7e5dde4102495faa1b6bbbfc23aa3e47deb7359c0bfe1eb004c0cb), [forward_grad.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-de2ab54ade7312701850d71a119a4f4ee4b9fc5a9c42a467cdd4e73c033531dd) and [build_variables.bzl](https://github.com/pytorch/pytorch/pull/49097/files#diff-dfdfa2efb17beddfd9094524f95351fd197db6c8857e96b436fb599870359325).
- Lowest level API and binding between Tensor and AutogradMeta in [TensorBody.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-7554853205392fa743357bf845ecc350a974ec049383248c12daaf2f4de04911), [TensorImpl.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-052bd9150ef8e09289ddf644b5a6830ede49207201cd41728f6d7cc6d9cead94), [TensorImpl.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-a15aae4cf23da44970db7cece62ff981265575c798c62f7b52d87c8809dfe2e1) and the rest of [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285R557-R677)
- API to access the forward primal that needs to be a differentiable function (and so in native_functions.yaml) [native_functions.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991) [NamedRegistrations.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-69bd3bea510c9b64e1633fa18c3ea63d4b8348dbad3a78ad9de844ab3e43dc1d), [VariableMethodsStub.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-23f5fcb737a2b289811fe0f4b65aef775e7c824b2e629ecd343df51405cd434f), [derivatives.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_python_functions.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_trace_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-54e0b976027bf8debefb959ff360b89ae93466970c843365b1b3a03806d868ce), [TraceTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-f34636741ad4a23d018e0c289bc750c3bad887b45660e1d6eaf440d234a78fbf) and [part of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R198-R243)
- c++ API [autograd.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-349028fbe8291a965a7a263c323b208fe071c35c66179ee997ef84fa81aa4b1e), [autograd.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-a3fe908d67dfec16a1fcde300de68b0701bf68b88db7451f29f2bee255cf30c9)
- python binding [init.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-c58a67c85191c22c9b3bb439117d8053edfd9dea839fa010cf967d404c3c630d)
- python API [forward_ad.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a4efad4ba18fffdfb264c21e5475997a24a743089a899f8ec1a5ff962c6738d9), [autograd/__init__.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-743abcafd32ad0e69f39ac5a91df4197b7e1921c135cacee7ef6dc829a8a7af8)
- c++ and python printing [Formatting.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-881dba501e71662e2e4818b4b016f739b344c8aed2f5edc6b871eda47a2aced0), [_tensor_str.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a7911f8d5e73adbff914d99fd7818ace2a7030b6a3748abe06ec6fc6e3df9cc3)
- Utility for formulas and updated manual functions to respect new view system as well as forward grad [FunctionsManual.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-6378bb6dc81a64dab676d61731341fa5d1088418f32a1473a33a0ccfc2357dc1), [FunctionsManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-4adbd88239afcd60e8198aab65d4f5e43b62314e34b80551e997a1ea503adea5) [rest of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R264-R433)
- Ensure SavedVariable save forward grad properly [saved_variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c1b8039d776241abe177d5aa99b79dd9489a9b3e529da8ab24c2e386c1238ae2), [saved_variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-cc9fba479b5beae06b2eea2e390d17796e0341c5b037a20b5bcaccbb0c341030)
[ghstack-poisoned]File tree
313 files changed
+16465
-4665
lines changed- .circleci/scripts
- .jenkins/pytorch
- aten/src
- ATen
- core
- boxing/impl
- dispatch
- op_registration
- cpu/vec256
- vsx
- cuda
- miopen
- native
- cpu
- cuda
- cudnn
- metal/mpscnn/tests
- miopen
- mkl
- quantized
- cpu
- qnnpack
- sparse
- templates
- test
- THC/generic
- TH/generic
- benchmarks
- operator_benchmark/pt_extension
- static_runtime
- binaries
- c10
- core
- test/util
- util
- caffe2
- contrib/fakelowp/test
- core
- mobile/contrib/libvulkan-stub/include/vulkan
- opt
- quantization/server
- cmake
- External
- Modules
- docs/source
- ios
- modules/detectron
- test
- backward_compatibility
- cpp_extensions
- cpp
- jit
- tensorexpr
- distributed
- _pipeline/sync
- skip
- fx
- jit
- onnx
- quantization
- third_party
- tools
- autograd
- codegen
- api
- fast_nvcc
- jit
- pyi
- torch
- _C
- csrc
- autograd
- distributed/c10d
- jit
- api
- codegen/cuda
- frontend
- ir
- passes
- python
- runtime
- static
- serialization
- tensorexpr
- utils
- distributed
- _pipeline/sync
- balance
- skip
- algorithms/ddp_comm_hooks
- nn/api
- fx
- experimental
- lib/c10d
- test
- linalg
- nn
- intrinsic/quantized/modules
- modules
- onnx
- package
- quantization
- fx
- testing/_internal
- distributed
- nn/api
- utils
- data
Some content is hidden
Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
313 files changed
+16465
-4665
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
34 | 34 | | |
35 | 35 | | |
36 | 36 | | |
37 | | - | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
38 | 44 | | |
39 | 45 | | |
40 | 46 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
37 | 37 | | |
38 | 38 | | |
39 | 39 | | |
40 | | - | |
41 | 40 | | |
42 | 41 | | |
43 | 42 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
21 | 21 | | |
22 | 22 | | |
23 | 23 | | |
| 24 | + | |
24 | 25 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
544 | 544 | | |
545 | 545 | | |
546 | 546 | | |
| 547 | + | |
547 | 548 | | |
548 | 549 | | |
549 | 550 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
684 | 684 | | |
685 | 685 | | |
686 | 686 | | |
687 | | - | |
688 | | - | |
| 687 | + | |
| 688 | + | |
689 | 689 | | |
690 | 690 | | |
691 | 691 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
233 | 233 | | |
234 | 234 | | |
235 | 235 | | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
236 | 262 | | |
237 | 263 | | |
238 | 264 | | |
| |||
971 | 997 | | |
972 | 998 | | |
973 | 999 | | |
| 1000 | + | |
| 1001 | + | |
| 1002 | + | |
| 1003 | + | |
| 1004 | + | |
974 | 1005 | | |
975 | 1006 | | |
976 | 1007 | | |
| |||
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
832 | 832 | | |
833 | 833 | | |
834 | 834 | | |
835 | | - | |
836 | | - | |
837 | | - | |
838 | | - | |
839 | | - | |
840 | | - | |
841 | | - | |
842 | | - | |
843 | | - | |
844 | | - | |
845 | | - | |
846 | | - | |
847 | | - | |
848 | | - | |
849 | | - | |
850 | | - | |
851 | | - | |
852 | | - | |
853 | | - | |
854 | | - | |
855 | | - | |
856 | | - | |
857 | | - | |
858 | | - | |
859 | | - | |
860 | | - | |
861 | | - | |
862 | | - | |
863 | | - | |
864 | | - | |
865 | | - | |
866 | | - | |
867 | | - | |
868 | | - | |
869 | | - | |
870 | | - | |
871 | | - | |
872 | | - | |
873 | | - | |
874 | | - | |
875 | | - | |
876 | | - | |
877 | | - | |
878 | | - | |
879 | | - | |
880 | | - | |
881 | | - | |
882 | 835 | | |
883 | 836 | | |
884 | 837 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
38 | 38 | | |
39 | 39 | | |
40 | 40 | | |
41 | | - | |
42 | | - | |
43 | 41 | | |
44 | 42 | | |
45 | 43 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
75 | 75 | | |
76 | 76 | | |
77 | 77 | | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
78 | 90 | | |
0 commit comments