Skip to content

Commit

Permalink
setitem support passing stop_gradient from value to tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
zyfncg committed Nov 8, 2021
1 parent ab2004b commit 3d40b53
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
6 changes: 6 additions & 0 deletions paddle/fluid/pybind/imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,12 @@ void BindImperative(py::module *m_ptr) {
auto value_tensor =
value_obj.cast<std::shared_ptr<imperative::VarBase>>();
ins.insert({"ValueTensor", {value_tensor}});

// pass the stop_gradient from value to tensor
if (!value_tensor->OverridedStopGradient() &&
self->OverridedStopGradient()) {
self->SetOverridedStopGradient(false);
}
} else if (py::isinstance<py::array>(value_obj)) {
auto value_tensor = std::shared_ptr<imperative::VarBase>(
new imperative::VarBase(false,
Expand Down
12 changes: 12 additions & 0 deletions python/paddle/fluid/tests/unittests/test_set_value_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,18 @@ def set_value5(t, value):
msg="The gradient of input should be \n{},\n but reveived {}".
format(value_grad, value.grad.numpy()))

# case 6: pass stop_gradient from value to x
x = paddle.zeros([8, 8], dtype='float32')
value = paddle.to_tensor([10], dtype='float32', stop_gradient=False)

self.assertTrue(x.stop_gradient)
self.assertTrue(x.is_leaf)

x[0, :] = value

self.assertTrue(~x.stop_gradient)
self.assertTrue(~x.is_leaf)

def test_static_graph(self):
paddle.enable_static()

Expand Down

0 comments on commit 3d40b53

Please sign in to comment.