Skip to content

Commit

Permalink
Add forward_gradients api and enable high-order differentiation for J…
Browse files Browse the repository at this point in the history
…acobian/Hessian (#43354)

* enable Jacobian,Hessian supporting new autograd

* fix prim mode failed in PR-CI-Windows

* add forward_gradients api

* add forward_gradients api

* skip test_autograd_functional_prim in windows ci

* fix test_autograd_funciton_prim timeouot

* remove the block parameter in prim2orig method

* remove duplicate to_tensors code snippet # test=allcases
  • Loading branch information
cxxly authored Jun 28, 2022
1 parent 82cd8d2 commit a97a8dd
Show file tree
Hide file tree
Showing 12 changed files with 496 additions and 82 deletions.
51 changes: 29 additions & 22 deletions python/paddle/autograd/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import paddle
from paddle.fluid import framework
from paddle.autograd.utils import as_tensors


def vjp(func, xs, v=None):
Expand Down Expand Up @@ -346,10 +347,16 @@ class _Jacobian(object):
"""

def __init__(self, func, xs):
self._xs = _separate(xs)
self._ys = func(*_as_tensors(self._xs))
self._flatten_xs = self._flatten(_as_tensors(self._xs))
self._flatten_ys = self._flatten(_as_tensors(self._ys))
# Skip separating in prim mode temporarily, as detach and clone are not
# primitive operators.
if not paddle.fluid._non_static_mode(
) and paddle.incubate.autograd.prim_enabled():
self._xs = xs
else:
self._xs = _separate(xs)
self._ys = func(*as_tensors(self._xs))
self._flatten_xs = self._flatten(as_tensors(self._xs))
self._flatten_ys = self._flatten(as_tensors(self._ys))
self._cache = {}

@property
Expand Down Expand Up @@ -385,9 +392,13 @@ def __getitem__(self, indexes):
return self._cached_evaluate(
indexes[self._lazy_axis])[other_indexes]
lazy_indexes = self._lazy_indexes(indexes)
part_jac = paddle.stack(
# Using concat and reshape to replace stack operator temporarily, as
# it is not a primitive operator.
shape = list(self.shape)
shape[self._lazy_axis] = len(lazy_indexes)
part_jac = paddle.concat(
[self._cached_evaluate(i) for i in lazy_indexes],
axis=self._lazy_axis)
axis=self._lazy_axis).reshape(shape)
return part_jac[self._shifted_indexes(indexes, len(lazy_indexes))]

def _cached_evaluate(self, k):
Expand Down Expand Up @@ -449,7 +460,7 @@ def _lazy_axis(self):

def _flatten(self, xs):
return paddle.concat(
tuple(x.reshape((-1, x.shape[-1])) for x in _as_tensors(xs)), 0)
tuple(x.reshape((-1, x.shape[-1])) for x in as_tensors(xs)), 0)

def _evaluate(self, row):
return self._flatten(_grad(self._flatten_ys[row, :], self._xs))
Expand All @@ -475,7 +486,7 @@ def _lazy_axis(self):

def _flatten(self, xs):
return paddle.concat(
tuple(x.reshape((x.shape[0], -1)) for x in _as_tensors(xs)), 1)
tuple(x.reshape((x.shape[0], -1)) for x in as_tensors(xs)), 1)

def _evaluate(self, row_index):
return self._flatten(_grad(self._flatten_ys[:, row_index], self._xs))
Expand Down Expand Up @@ -526,10 +537,6 @@ def _multi_index(indexes, shape):
return tuple(positive_indexes)


def _as_tensors(xs):
return (xs, ) if isinstance(xs, framework.Variable) else xs


def _stack_tensor_or_return_none(origin_list):
assert len(origin_list) > 0, "Can't not stack an empty list"
return paddle.stack(origin_list, axis=0) if isinstance(
Expand Down Expand Up @@ -683,7 +690,7 @@ def _check_v_shape(v, refs):
if v is None:
return

v, refs = _as_tensors(v), _as_tensors(refs)
v, refs = as_tensors(v), as_tensors(refs)
if len(refs) != len(v):
raise RuntimeError(f"The argument v is a tuple of invalid length:"
f"should be {len(refs)} but got {len(v)}.")
Expand Down Expand Up @@ -805,8 +812,8 @@ def func(x, y):
# [0., 0., 0., 2.]]), None))
'''
inputs = _as_tensors(inputs)
outputs = _as_tensors(func(*inputs))
inputs = as_tensors(inputs)
outputs = as_tensors(func(*inputs))
fin_size = len(inputs)
fout_size = len(outputs)
flat_outputs = tuple(
Expand Down Expand Up @@ -942,8 +949,8 @@ def func(x, y):
'''

inputs = _as_tensors(inputs)
outputs = _as_tensors(func(*inputs))
inputs = as_tensors(inputs)
outputs = as_tensors(func(*inputs))

batch_size = inputs[0].shape[0]
for input in inputs:
Expand Down Expand Up @@ -1103,7 +1110,7 @@ def func(x, y):
# [0., 2., 0., 2., 0., 2., 0., 2.]]), None), (None, None))
'''
inputs = _as_tensors(inputs)
inputs = as_tensors(inputs)
outputs = func(*inputs)
batch_size = inputs[0].shape[0]
for input in inputs:
Expand Down Expand Up @@ -1234,7 +1241,7 @@ def func(x, y):
# [0., 1., 1., 2.]]), None), (None, None))
'''
inputs = _as_tensors(inputs)
inputs = as_tensors(inputs)
outputs = func(*inputs)
assert isinstance(outputs, paddle.Tensor) and outputs.shape == [
1
Expand Down Expand Up @@ -1339,12 +1346,12 @@ def func(x, y):
# [[8., 8.],
# [8., 8.]]), None])
'''
xs = _as_tensors(inputs)
xs = as_tensors(inputs)
if v is not None:
v = _as_tensors(v)
v = as_tensors(v)
xs, v = _separate(xs), _separate(v)
outputs = func(*xs)
ys = _as_tensors(outputs)
ys = as_tensors(outputs)
assert len(ys) == 1 and isinstance(
ys[0], framework.Variable
) and ys[0].shape == [
Expand Down
26 changes: 26 additions & 0 deletions python/paddle/autograd/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing

from paddle.fluid import framework


def as_tensors(xs):
if isinstance(xs, framework.Variable):
return (xs, )
elif isinstance(xs, typing.Sequence):
return tuple(xs)
else:
return xs
11 changes: 10 additions & 1 deletion python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,19 @@ file(
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0)

if(WIN32)
# TODO: Fix these unittests failed on Windows
list(REMOVE_ITEM TEST_OPS test_autograd_functional_prim)
list(REMOVE_ITEM TEST_OPS test_primapi)
endif()

foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()

set_tests_properties(test_autograd_functional_dynamic PROPERTIES TIMEOUT 160)
set_tests_properties(test_autograd_functional_dynamic PROPERTIES TIMEOUT 200)
set_tests_properties(test_autograd_functional_static PROPERTIES TIMEOUT 160)
set_tests_properties(test_gradients_and_minimize PROPERTIES TIMEOUT 60)
if(NOT WIN32)
set_tests_properties(test_autograd_functional_prim PROPERTIES TIMEOUT 60)
endif()
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import paddle.fluid as fluid
import paddle.compat as cpt
import paddle.nn.functional as F
from paddle.autograd.functional import _as_tensors
from paddle.autograd.utils import as_tensors
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph, _in_eager_without_dygraph_check

import config
Expand All @@ -33,7 +33,7 @@


def make_v(f, inputs):
outputs = _as_tensors(f(*inputs))
outputs = as_tensors(f(*inputs))
return [paddle.ones_like(x) for x in outputs]


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing
import unittest

import numpy as np
import paddle

import config
import utils


@utils.place(config.DEVICES)
@utils.parameterize((utils.TEST_CASE_NAME, 'fun', 'args', 'dtype'), (
('unary_float32', paddle.tanh, (np.random.rand(2, 3), ), 'float32'),
('binary_float32', paddle.matmul,
(np.random.rand(2, 3), np.random.rand(3, 2)), 'float32'),
('unary_float64', paddle.tanh, (np.random.rand(2, 3), ), 'float64'),
('binary_float64', paddle.matmul,
(np.random.rand(2, 3), np.random.rand(3, 2)), 'float64'),
))
class TestJacobianPrim(unittest.TestCase):

@classmethod
def setUpClass(cls):
cls.args = [arg.astype(cls.dtype) for arg in cls.args]
cls._rtol = config.TOLERANCE.get(
cls.dtype).get('first_order_grad').get('rtol')
cls._atol = config.TOLERANCE.get(
cls.dtype).get('first_order_grad').get('atol')

def setUp(self):
paddle.enable_static()
paddle.incubate.autograd.enable_prim()

def tearDown(self):
paddle.incubate.autograd.disable_prim()
paddle.disable_static()

def test_jacobian_prim(self):

def wrapper(fun, args):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
static_args = [
paddle.static.data(f'arg{i}', arg.shape, self.dtype)
for i, arg in enumerate(args)
]
for arg in static_args:
arg.stop_gradient = False
jac = paddle.incubate.autograd.Jacobian(fun, static_args)[:]
if paddle.incubate.autograd.prim_enabled():
paddle.incubate.autograd.prim2orig()
exe = paddle.static.Executor()
exe.run(sp)
[jac] = exe.run(mp,
feed={f'arg{i}': arg
for i, arg in enumerate(args)},
fetch_list=[jac])
return jac

paddle.incubate.autograd.enable_prim()
prim_jac = wrapper(self.fun, self.args)
paddle.incubate.autograd.disable_prim()
orig_jac = wrapper(self.fun, self.args)

np.testing.assert_allclose(orig_jac,
prim_jac,
rtol=self._rtol,
atol=self._atol)


@utils.place(config.DEVICES)
@utils.parameterize((utils.TEST_CASE_NAME, 'fun', 'args', 'dtype'), (
('unary_float32', paddle.tanh, (np.random.rand(1), ), 'float32'),
('binary_float32', paddle.multiply,
(np.random.rand(1), np.random.rand(1)), 'float32'),
('unary_float64', paddle.tanh, (np.random.rand(1), ), 'float64'),
('binary_float64', paddle.multiply,
(np.random.rand(1), np.random.rand(1)), 'float64'),
))
class TestHessianPrim(unittest.TestCase):

@classmethod
def setUpClass(cls):
cls.args = [arg.astype(cls.dtype) for arg in cls.args]
cls._rtol = config.TOLERANCE.get(
cls.dtype).get('second_order_grad').get('rtol')
cls._atol = config.TOLERANCE.get(
cls.dtype).get('second_order_grad').get('atol')

def setUp(self):
paddle.enable_static()
paddle.incubate.autograd.enable_prim()

def tearDown(self):
paddle.incubate.autograd.disable_prim()
paddle.disable_static()

def test_jacobian_prim(self):

def wrapper(fun, args):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
static_args = [
paddle.static.data(f'arg{i}', arg.shape, self.dtype)
for i, arg in enumerate(args)
]
for arg in static_args:
arg.stop_gradient = False
hessian = paddle.incubate.autograd.Hessian(fun, static_args)[:]
if paddle.incubate.autograd.prim_enabled():
paddle.incubate.autograd.prim2orig()
exe = paddle.static.Executor()
exe.run(sp)
[hessian
] = exe.run(mp,
feed={f'arg{i}': arg
for i, arg in enumerate(args)},
fetch_list=[hessian])
return hessian

paddle.incubate.autograd.enable_prim()
prim_jac = wrapper(self.fun, self.args)
paddle.incubate.autograd.disable_prim()
orig_jac = wrapper(self.fun, self.args)

np.testing.assert_allclose(orig_jac,
prim_jac,
rtol=self._rtol,
atol=self._atol)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit a97a8dd

Please sign in to comment.