Skip to content

Commit

Permalink
Fix set value grad (PaddlePaddle#59034)
Browse files Browse the repository at this point in the history
* first fix the UT

* fix set value grad

* polish code

* add static mode backward test

* always has input valuetensor

* add dygraph test
  • Loading branch information
zoooo0820 committed Jan 18, 2024
1 parent d788e9b commit 501e520
Show file tree
Hide file tree
Showing 8 changed files with 201 additions and 28 deletions.
44 changes: 19 additions & 25 deletions paddle/fluid/operators/set_value_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,32 +151,26 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {

protected:
void Apply(GradOpPtr<T> op) const override {
if (this->HasInput("ValueTensor")) {
op->SetType("set_value_grad");

op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("ValueTensor", this->Input("ValueTensor"));
if (this->HasInput("StartsTensorList")) {
op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
}
if (this->HasInput("EndsTensorList")) {
op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
}
if (this->HasInput("StepsTensorList")) {
op->SetInput("StepsTensorList", this->Input("StepsTensorList"));
}

op->SetAttrMap(this->Attrs());

op->SetOutput(framework::GradVarName("ValueTensor"),
this->InputGrad("ValueTensor"));
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));

} else {
op->SetType("assign");
op->SetInput("X", this->OutputGrad("Out"));
op->SetOutput("Out", this->InputGrad("Input"));
op->SetType("set_value_grad");
op->SetInput("ValueTensor", this->Input("ValueTensor"));
op->SetOutput(framework::GradVarName("ValueTensor"),
this->InputGrad("ValueTensor"));

op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));

if (this->HasInput("StartsTensorList")) {
op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
}
if (this->HasInput("EndsTensorList")) {
op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
}
if (this->HasInput("StepsTensorList")) {
op->SetInput("StepsTensorList", this->Input("StepsTensorList"));
}

op->SetAttrMap(this->Attrs());

op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
}
};

Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/api/yaml/legacy_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -614,14 +614,14 @@

- backward_op : set_value_grad
forward : set_value (Tensor x, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes, int64_t[] shape, Scalar[] values) -> Tensor(out)
args : (Tensor out_grad)
args : (Tensor out_grad, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes)
output : Tensor(x_grad)
infer_meta:
func: UnchangedInferMeta
param: [out_grad]
kernel:
func: assign
param: [out_grad]
func: set_value_with_scalar_grad
param: [out_grad, starts, ends, steps, axes, decrease_axes, none_axes]

- backward_op : set_value_with_tensor_grad
forward: set_value_with_tensor (Tensor x, Tensor values, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes) -> Tensor(out)
Expand Down
17 changes: 17 additions & 0 deletions paddle/phi/kernels/cpu/set_value_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,20 @@ PD_REGISTER_KERNEL(set_value_grad,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(set_value_with_scalar_grad,
CPU,
ALL_LAYOUT,
phi::SetValueWithScalarGradKernel,
float,
double,
int,
int64_t,
bool,
int16_t,
uint8_t,
int8_t,
phi::dtype::bfloat16,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
17 changes: 17 additions & 0 deletions paddle/phi/kernels/gpu/set_value_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,20 @@ PD_REGISTER_KERNEL(set_value_grad,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(set_value_with_scalar_grad,
GPU,
ALL_LAYOUT,
phi::SetValueWithScalarGradKernel,
float,
double,
int,
int64_t,
bool,
int16_t,
uint8_t,
int8_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
22 changes: 22 additions & 0 deletions paddle/phi/kernels/impl/set_value_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -341,4 +341,26 @@ void SetValueGradKernel(const Context& dev_ctx,
}
}

template <typename T, typename Context>
void SetValueWithScalarGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const IntArray& starts,
const IntArray& ends,
const IntArray& steps,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
DenseTensor* x_grad) {
SetValueGradKernel<T, Context>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
}

} // namespace phi
10 changes: 10 additions & 0 deletions paddle/phi/kernels/set_value_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,14 @@ void SetValueGradKernel(const Context& dev_ctx,
DenseTensor* x_grad,
DenseTensor* value_grad);

template <typename T, typename Context>
void SetValueWithScalarGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const IntArray& starts,
const IntArray& ends,
const IntArray& steps,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
DenseTensor* x_grad);
} // namespace phi
31 changes: 31 additions & 0 deletions paddle/phi/kernels/xpu/set_value_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,28 @@ void SetValueGradKernel(const Context& dev_ctx,
}
}

template <typename T, typename Context>
void SetValueWithScalarGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const IntArray& starts,
const IntArray& ends,
const IntArray& steps,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
DenseTensor* x_grad) {
SetValueGradKernel<T, Context>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
}

} // namespace phi

PD_REGISTER_KERNEL(set_value_grad,
Expand All @@ -407,3 +429,12 @@ PD_REGISTER_KERNEL(set_value_grad,
phi::dtype::float16,
int,
int64_t) {}

PD_REGISTER_KERNEL(set_value_with_scalar_grad,
XPU,
ALL_LAYOUT,
phi::SetValueWithScalarGradKernel,
float,
phi::dtype::float16,
int,
int64_t) {}
82 changes: 82 additions & 0 deletions test/legacy_test/test_set_value_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1978,5 +1978,87 @@ def test_check_grad(self):
self.check_grad_with_place(place, ['Input'], 'Out', check_dygraph=False)


class TestSetValueWithScalarInStatic(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.shape = (10, 2)
self.exe = paddle.static.Executor()
self.train_program = paddle.static.Program()
self.startup_program = paddle.static.Program()

def test_value_input_is_scalar(self):
with paddle.static.program_guard(
self.train_program, self.startup_program
):
x = paddle.ones(self.shape)
x.stop_gradient = False
y = x * 1

# mock test case x[0, 0] = 10 with no ValueTensor input
inputs = {
'Input': y,
}
attrs = {
'axes': [0, 1],
'starts': [0, 0],
'ends': [1, 1],
'steps': [1, 1],
'values': [10],
'shape': [1],
}

helper = LayerHelper("set_value")
out = helper.create_variable_for_type_inference(dtype=y.dtype)

helper.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': out},
attrs=attrs,
)

np_data = np.ones(self.shape).astype('float32')

paddle.static.append_backward(out.sum())
res = self.exe.run(
self.train_program, fetch_list=[out, x.grad_name]
)

np_data[0, 0] = 10
expected_x_grad = np.ones(self.shape)
expected_x_grad[0, 0] = 0

np.testing.assert_array_equal(res[0], np_data)
np.testing.assert_array_equal(res[1], expected_x_grad)


class TestSetValueWithScalarInDygraph(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.shape = (10, 2)

def test_value_input_is_scalar(self):
x = paddle.ones(self.shape)
x.stop_gradient = False
y = x * 1

# mock test case x[0, 0] = 10 with no ValueTensor input
out = paddle._C_ops.set_value(
y, [0, 0], [1, 1], [1, 1], [0, 1], [], [], [1], [10.0]
)

loss = out.sum()
loss.backward()

np_data = np.ones(self.shape).astype('float32')
np_data[0, 0] = 10

expected_x_grad = np.ones(self.shape)
expected_x_grad[0, 0] = 0

np.testing.assert_array_equal(out, np_data)
np.testing.assert_array_equal(x.grad, expected_x_grad)


if __name__ == '__main__':
unittest.main()

0 comments on commit 501e520

Please sign in to comment.