Skip to content

Commit

Permalink
fix duplicate slice logic in _grad (#44396)
Browse files Browse the repository at this point in the history
  • Loading branch information
cxxly authored Jul 18, 2022
1 parent 4c1e77d commit 1d12832
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,11 @@ def test_all_cases(self):
self.func_vjp_nested()
self.func_vjp_aliased_input()

def test_input_single_tensor(self):
self.assertIsInstance(
paddle.incubate.autograd.vjp(paddle.tanh, paddle.rand((3, 4)))[1],
paddle.fluid.framework.Variable)


@utils.place(config.DEVICES)
@utils.parameterize(
Expand Down
10 changes: 6 additions & 4 deletions python/paddle/incubate/autograd/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,13 +565,15 @@ def _grad(ys, xs, v=None):
inputs.
"""
if paddle.fluid._non_static_mode():
# paddle.grad returns a list though the inputs is a signle Tensor. The
# follow code snippet fixes the problem by return the first element of
# xs_grad when the xs is a signle Tensor.
xs_grad = paddle.grad(ys, xs, v, create_graph=True, allow_unused=True)
if isinstance(xs, paddle.fluid.framework.Variable) and isinstance(
xs_grad, typing.Sequence) and len(xs_grad) > 0:
xs_grad = xs_grad[0]
else:
xs_grad = paddle.incubate.autograd.grad(ys, xs, v)

if isinstance(xs, paddle.fluid.framework.Variable):
xs_grad = xs_grad[0]

return _replace_none_with_zero_tensor(xs_grad, xs)


Expand Down
11 changes: 9 additions & 2 deletions python/paddle/incubate/autograd/primapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,16 @@ def grad(outputs, inputs, grad_outputs=None):
paddle.incubate.autograd.disable_prim()
paddle.disable_static()
"""

if not utils.prim_enabled():
return backward.gradients(outputs, inputs, grad_outputs)
grad_inputs = backward.gradients(outputs, inputs, grad_outputs)
# backward.gradients returns a list though the inputs is a signle Tensor.
# The follow code snippet fixes the problem by return the first element
# of grad_inputs when the inputs is a signle Tensor.
if isinstance(inputs, framework.Variable) and isinstance(
grad_inputs, typing.Sequence) and len(grad_inputs) > 0:
return grad_inputs[0]
else:
return grad_inputs

if not isinstance(outputs, (framework.Variable, typing.Sequence)):
raise TypeError(f'Expected outputs is Tensor|Sequence[Tesnor], '
Expand Down

0 comments on commit 1d12832

Please sign in to comment.