diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 8b01f02ee2c3a..4403eb469723a 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -985,6 +985,12 @@ void BindImperative(py::module *m_ptr) {
             auto value_tensor =
                 value_obj.cast<std::shared_ptr<imperative::VarBase>>();
             ins.insert({"ValueTensor", {value_tensor}});
+
+            // pass the stop_gradient from value to tensor
+            if (!value_tensor->OverridedStopGradient() &&
+                self->OverridedStopGradient()) {
+              self->SetOverridedStopGradient(false);
+            }
           } else if (py::isinstance<py::array>(value_obj)) {
             auto value_tensor = std::shared_ptr<imperative::VarBase>(
                 new imperative::VarBase(false,
diff --git a/python/paddle/fluid/tests/unittests/test_set_value_op.py b/python/paddle/fluid/tests/unittests/test_set_value_op.py
index 21f506d03ce68..e9809318cb393 100644
--- a/python/paddle/fluid/tests/unittests/test_set_value_op.py
+++ b/python/paddle/fluid/tests/unittests/test_set_value_op.py
@@ -1154,6 +1154,18 @@ def set_value5(t, value):
             msg="The gradient of input should be \n{},\n but reveived {}".
             format(value_grad, value.grad.numpy()))
 
+        # case 6: pass stop_gradient from value to x
+        x = paddle.zeros([8, 8], dtype='float32')
+        value = paddle.to_tensor([10], dtype='float32', stop_gradient=False)
+
+        self.assertTrue(x.stop_gradient)
+        self.assertTrue(x.is_leaf)
+
+        x[0, :] = value
+
+        self.assertFalse(x.stop_gradient)
+        self.assertFalse(x.is_leaf)
+
     def test_static_graph(self):
         paddle.enable_static()
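
For context, here is a minimal dygraph sketch of the behavior this change enables (hypothetical usage, assuming a Paddle build with this patch applied): assigning a differentiable `value` into a tensor that has `stop_gradient=True` now flips the target to `stop_gradient=False`, so gradients can flow back through the `__setitem__` assignment to `value`.

```python
import paddle

# Hypothetical usage sketch, assuming a Paddle build with this patch.
x = paddle.zeros([8, 8], dtype='float32')        # stop_gradient=True by default
value = paddle.to_tensor([10.0], stop_gradient=False)

x[0, :] = value    # __setitem__ lowers to the set_value op

# The target inherits stop_gradient=False from `value`
# and is no longer a leaf of the autograd graph.
assert not x.stop_gradient
assert not x.is_leaf

# Gradients now reach `value` through the assignment: `value` was
# broadcast to the 8 elements of row 0, so with loss = x.sum() its
# gradient should accumulate to [8.].
loss = x.sum()
loss.backward()
print(value.grad)
```

Without the C++ hunk above, `x` would keep `stop_gradient=True` after the assignment and `value.grad` would stay empty, which is exactly what the new test case 6 guards against.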