Skip to content

Commit 26992ec

Browse files
authored
fix solve grad (#72806)
1 parent d048598 commit 26992ec

File tree

2 files changed: 39 additions and 3 deletions

2 files changed: 39 additions and 3 deletions

paddle/phi/kernels/impl/solve_grad_kernel_impl.h

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include "paddle/phi/backends/cpu/cpu_context.h"
1818
#include "paddle/phi/backends/gpu/gpu_context.h"
1919
#include "paddle/phi/kernels/expand_as_kernel.h"
20+
#include "paddle/phi/kernels/full_kernel.h"
2021
#include "paddle/phi/kernels/funcs/blas/blas.h"
2122
#include "paddle/phi/kernels/funcs/math_function.h"
2223
#include "paddle/phi/kernels/funcs/matrix_solve.h"
@@ -78,6 +79,24 @@ void SolveGradKernel(const Context& dev_ctx,
7879
const DenseTensor& dout,
7980
DenseTensor* dx,
8081
DenseTensor* dy) {
82+
if (dout.numel() == 0) {
83+
if (dx) {
84+
dev_ctx.template Alloc<T>(dx);
85+
if (dx->numel() != 0) {
86+
phi::Full<T, Context>(
87+
dev_ctx, phi::IntArray(common::vectorize(dx->dims())), 0, dx);
88+
}
89+
}
90+
if (dy) {
91+
dev_ctx.template Alloc<T>(dy);
92+
if (dy->numel() != 0) {
93+
phi::Full<T, Context>(
94+
dev_ctx, phi::IntArray(common::vectorize(dy->dims())), 0, dy);
95+
}
96+
}
97+
return;
98+
}
99+
81100
bool is_vector = false;
82101
is_vector = is_vector_rhs(x, y);
83102
DenseTensor tmp_y;

test/legacy_test/test_solve_op.py

Lines changed: 20 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -985,22 +985,39 @@ def run(place, x_shape, y_shape):
985985
input_x_np = np.random.random(x_shape).astype(self.dtype)
986986
input_y_np = np.random.random(y_shape).astype(self.dtype)
987987

988-
tensor_input_x = paddle.to_tensor(input_x_np)
989-
tensor_input_y = paddle.to_tensor(input_y_np)
988+
tensor_input_x = paddle.to_tensor(
989+
input_x_np, stop_gradient=False
990+
)
991+
tensor_input_y = paddle.to_tensor(
992+
input_y_np, stop_gradient=False
993+
)
990994

991995
numpy_output = np.linalg.solve(input_x_np, input_y_np)
992996
paddle_output = paddle.linalg.solve(
993-
tensor_input_x, tensor_input_y, left=False
997+
tensor_input_x, tensor_input_y, left=True
994998
)
995999
np.testing.assert_allclose(
9961000
numpy_output, paddle_output.numpy(), rtol=0.0001
9971001
)
9981002
self.assertEqual(
9991003
numpy_output.shape, paddle_output.numpy().shape
10001004
)
1005+
loss = paddle.sum(paddle_output)
1006+
loss.backward()
1007+
np.testing.assert_allclose(
1008+
tensor_input_x.grad.shape, tensor_input_x.shape
1009+
)
1010+
np.testing.assert_allclose(
1011+
tensor_input_y.grad.shape, tensor_input_y.shape
1012+
)
10011013

10021014
for place in self.place:
1015+
run(place, x_shape=[1, 10, 10], y_shape=[1, 10, 10])
1016+
run(place, x_shape=[0, 10, 10], y_shape=[0, 10, 10])
1017+
run(place, x_shape=[0, 10, 10], y_shape=[1, 10, 10])
10031018
run(place, x_shape=[10, 0, 0], y_shape=[10, 0, 0])
1019+
run(place, x_shape=[10, 1, 1], y_shape=[10, 1, 0])
1020+
10041021
with self.assertRaises(ValueError) as context:
10051022
run(place, x_shape=[10, 0, 0], y_shape=[10])
10061023

0 commit comments

Comments (0)