Add forward_gradients api and enable high-order differentiation for Jacobian/Hessian (#43354)

* enable Jacobian, Hessian supporting new autograd
* fix prim mode failed in PR-CI-Windows
* add forward_gradients api
* add forward_gradients api
* skip test_autograd_functional_prim in windows ci
* fix test_autograd_functional_prim timeout
* remove the block parameter in prim2orig method
* remove duplicate to_tensors code snippet # test=allcases
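The commit message announces a new forward_gradients api, but the files shown below only cover the tensor utility module and the prim-mode tests. The snippet that follows is therefore a minimal usage sketch, not the confirmed interface: the signature forward_gradients(ys, xs), the default all-ones tangent, and the variable names are assumptions; only the function name and the prim2orig lowering step come from this commit.

# Hypothetical sketch of calling the new API in static prim mode.
# Assumed: the signature forward_gradients(ys, xs) and a default tangent.
import numpy as np
import paddle

paddle.enable_static()
paddle.incubate.autograd.enable_prim()

main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data('x', shape=[2, 3], dtype='float32')
    x.stop_gradient = False
    y = paddle.tanh(x)
    # Forward-mode derivative of y with respect to x.
    y_dot = paddle.incubate.autograd.forward_gradients(y, x)
    # Lower autodiff primitives to original operators before execution.
    paddle.incubate.autograd.prim2orig()

exe = paddle.static.Executor()
exe.run(startup_prog)
[y_dot_value] = exe.run(main_prog,
                        feed={'x': np.random.rand(2, 3).astype('float32')},
                        fetch_list=[y_dot])

paddle.incubate.autograd.disable_prim()
paddle.disable_static()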
Showing 12 changed files with 496 additions and 82 deletions.
@@ -0,0 +1,26 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing

from paddle.fluid import framework


def as_tensors(xs):
    # Normalize the input into a tuple of tensors: a single Variable becomes
    # a one-element tuple, any sequence becomes a tuple, and anything else is
    # returned unchanged.
    if isinstance(xs, framework.Variable):
        return (xs, )
    elif isinstance(xs, typing.Sequence):
        return tuple(xs)
    else:
        return xs
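For illustration, as_tensors behaves as follows when used in a static program; the tensor x below is a hypothetical example, not part of the diff.

# Illustration only: x is a hypothetical static-graph Variable.
import paddle

paddle.enable_static()
x = paddle.static.data('x', shape=[2, 3], dtype='float32')

as_tensors(x)         # single Variable -> one-element tuple: (x,)
as_tensors([x, x])    # any sequence -> tuple: (x, x)
as_tensors((x, ))     # tuples pass through as tuples
as_tensors(None)      # anything else is returned as-is: None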
python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_prim.py
149 changes: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing
import unittest

import numpy as np
import paddle

import config
import utils


@utils.place(config.DEVICES)
@utils.parameterize((utils.TEST_CASE_NAME, 'fun', 'args', 'dtype'), (
    ('unary_float32', paddle.tanh, (np.random.rand(2, 3), ), 'float32'),
    ('binary_float32', paddle.matmul,
     (np.random.rand(2, 3), np.random.rand(3, 2)), 'float32'),
    ('unary_float64', paddle.tanh, (np.random.rand(2, 3), ), 'float64'),
    ('binary_float64', paddle.matmul,
     (np.random.rand(2, 3), np.random.rand(3, 2)), 'float64'),
))
class TestJacobianPrim(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.args = [arg.astype(cls.dtype) for arg in cls.args]
        cls._rtol = config.TOLERANCE.get(
            cls.dtype).get('first_order_grad').get('rtol')
        cls._atol = config.TOLERANCE.get(
            cls.dtype).get('first_order_grad').get('atol')

    def setUp(self):
        paddle.enable_static()
        paddle.incubate.autograd.enable_prim()

    def tearDown(self):
        paddle.incubate.autograd.disable_prim()
        paddle.disable_static()

    def test_jacobian_prim(self):

        def wrapper(fun, args):
            # Build a static program that computes the full Jacobian of `fun`.
            mp = paddle.static.Program()
            sp = paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                static_args = [
                    paddle.static.data(f'arg{i}', arg.shape, self.dtype)
                    for i, arg in enumerate(args)
                ]
                for arg in static_args:
                    arg.stop_gradient = False
                jac = paddle.incubate.autograd.Jacobian(fun, static_args)[:]
                # In prim mode the Jacobian is built from autodiff primitives;
                # lower them to original operators before execution.
                if paddle.incubate.autograd.prim_enabled():
                    paddle.incubate.autograd.prim2orig()
            exe = paddle.static.Executor()
            exe.run(sp)
            [jac] = exe.run(mp,
                            feed={f'arg{i}': arg
                                  for i, arg in enumerate(args)},
                            fetch_list=[jac])
            return jac

        # The Jacobian computed with primitive operators should match the one
        # computed with the original operators.
        paddle.incubate.autograd.enable_prim()
        prim_jac = wrapper(self.fun, self.args)
        paddle.incubate.autograd.disable_prim()
        orig_jac = wrapper(self.fun, self.args)

        np.testing.assert_allclose(orig_jac,
                                   prim_jac,
                                   rtol=self._rtol,
                                   atol=self._atol)


@utils.place(config.DEVICES)
@utils.parameterize((utils.TEST_CASE_NAME, 'fun', 'args', 'dtype'), (
    ('unary_float32', paddle.tanh, (np.random.rand(1), ), 'float32'),
    ('binary_float32', paddle.multiply,
     (np.random.rand(1), np.random.rand(1)), 'float32'),
    ('unary_float64', paddle.tanh, (np.random.rand(1), ), 'float64'),
    ('binary_float64', paddle.multiply,
     (np.random.rand(1), np.random.rand(1)), 'float64'),
))
class TestHessianPrim(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.args = [arg.astype(cls.dtype) for arg in cls.args]
        cls._rtol = config.TOLERANCE.get(
            cls.dtype).get('second_order_grad').get('rtol')
        cls._atol = config.TOLERANCE.get(
            cls.dtype).get('second_order_grad').get('atol')

    def setUp(self):
        paddle.enable_static()
        paddle.incubate.autograd.enable_prim()

    def tearDown(self):
        paddle.incubate.autograd.disable_prim()
        paddle.disable_static()

    def test_hessian_prim(self):

        def wrapper(fun, args):
            # Build a static program that computes the full Hessian of `fun`.
            mp = paddle.static.Program()
            sp = paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                static_args = [
                    paddle.static.data(f'arg{i}', arg.shape, self.dtype)
                    for i, arg in enumerate(args)
                ]
                for arg in static_args:
                    arg.stop_gradient = False
                hessian = paddle.incubate.autograd.Hessian(fun, static_args)[:]
                # Lower autodiff primitives to original operators before
                # execution, as in the Jacobian test above.
                if paddle.incubate.autograd.prim_enabled():
                    paddle.incubate.autograd.prim2orig()
            exe = paddle.static.Executor()
            exe.run(sp)
            [hessian] = exe.run(mp,
                                feed={f'arg{i}': arg
                                      for i, arg in enumerate(args)},
                                fetch_list=[hessian])
            return hessian

        # The Hessian computed with primitive operators should match the one
        # computed with the original operators.
        paddle.incubate.autograd.enable_prim()
        prim_hessian = wrapper(self.fun, self.args)
        paddle.incubate.autograd.disable_prim()
        orig_hessian = wrapper(self.fun, self.args)

        np.testing.assert_allclose(orig_hessian,
                                   prim_hessian,
                                   rtol=self._rtol,
                                   atol=self._atol)


if __name__ == "__main__":
    unittest.main()
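Both test classes above share the same wrapper pattern: build the derivative in a static program, lower the primitives with prim2orig when prim mode is on, and compare the result against the one computed with the original operators. A condensed, self-contained sketch of that lowering step follows; the paddle.tanh function, the 2x3 input, and the variable names are illustrative choices, not part of the test suite.

# Condensed version of the wrapper used in the tests above.
import numpy as np
import paddle

paddle.enable_static()
paddle.incubate.autograd.enable_prim()

mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
    x = paddle.static.data('x', shape=[2, 3], dtype='float32')
    x.stop_gradient = False
    # In prim mode this emits autodiff primitive operators.
    jac = paddle.incubate.autograd.Jacobian(paddle.tanh, x)[:]
    # The Executor cannot run primitives directly, so lower them back to
    # original operators first.
    if paddle.incubate.autograd.prim_enabled():
        paddle.incubate.autograd.prim2orig()

exe = paddle.static.Executor()
exe.run(sp)
[jac_value] = exe.run(mp,
                      feed={'x': np.random.rand(2, 3).astype('float32')},
                      fetch_list=[jac])

paddle.incubate.autograd.disable_prim()
paddle.disable_static()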