Skip to content
19 changes: 18 additions & 1 deletion paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/matrix_reduce.h"
#include "paddle/phi/kernels/funcs/tril_triu_compute.h"
#include "paddle/phi/kernels/transpose_kernel.h"

namespace phi {

template <typename T, typename Context>
Expand All @@ -39,6 +39,23 @@ void CholeskySolveGradKernel(const Context& dev_ctx,
bool upper,
DenseTensor* dx,
DenseTensor* dy) {
  // Zero-size fast path: when dout has no elements there is nothing to
  // back-propagate, so both gradients are identically zero and the solver
  // below can be skipped entirely.
  if (dout.numel() == 0) {
    if (dx) {
      dev_ctx.template Alloc<T>(dx);
      // dx can still hold elements even though dout is empty; zero-fill it.
      // The numel() guard avoids launching Full on an empty tensor.
      if (dx->numel() != 0) {
        phi::Full<T, Context>(
            dev_ctx, phi::IntArray(common::vectorize(dx->dims())), 0, dx);
      }
    }
    if (dy) {
      dev_ctx.template Alloc<T>(dy);
      // Same zero-fill treatment for the gradient w.r.t. y.
      if (dy->numel() != 0) {
        phi::Full<T, Context>(
            dev_ctx, phi::IntArray(common::vectorize(dy->dims())), 0, dy);
      }
    }
    return;
  }
// get broadcast dim
std::vector<int64_t> x_bst_dims_vec;
std::vector<int64_t> y_bst_dims_vec;
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/kernels/impl/cholesky_solve_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ void CholeskySolveKernel(const Context& dev_ctx,
const DenseTensor& y,
bool upper,
DenseTensor* out) {
  // Zero-size fast path: an empty output requires no solve — just allocate
  // the (empty) buffer so downstream code sees a valid tensor, and return.
  if (out && out->numel() == 0) {
    dev_ctx.template Alloc<T>(out);
    return;
  }
// get broadcast dim
std::vector<int64_t> x_bst_dims_vec;
std::vector<int64_t> y_bst_dims_vec;
Expand Down
58 changes: 58 additions & 0 deletions test/legacy_test/test_cholesky_solve_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,5 +282,63 @@ def test_errors_2(self):
self.assertRaises(ValueError, paddle.linalg.cholesky_solve, x7, y7)


# Zero-size tensor tests for the paddle.linalg.cholesky_solve API.
class TestCholeskySolveAPIZeroSize(unittest.TestCase):
    """Verify cholesky_solve forward and backward with zero-size inputs.

    Subclasses override ``init_shape`` to cover different empty-tensor
    layouts; the forward result shape and the gradient shapes must always
    match the inputs.
    """

    def setUp(self):
        np.random.seed(2025)
        self.place = [paddle.CPUPlace()]
        self.dtype = "float64"
        self.upper = True
        if core.is_compiled_with_cuda():
            self.place.append(paddle.CUDAPlace(0))
        self.init_shape()

    def init_shape(self):
        # Base case: zero-column right-hand side with a regular 10x10 factor.
        self.x_shape = [10, 0]
        self.y_shape = [10, 10]
        self.expected_shape = [10, 0]

    def test_dygraph(self):
        # Test in dynamic (imperative) mode on every available place.
        def run(place):
            paddle.disable_static(place)
            x_np = np.random.random(self.x_shape).astype(self.dtype)
            y_np = np.random.random(self.y_shape).astype(self.dtype)

            x = paddle.to_tensor(x_np, stop_gradient=False)
            y = paddle.to_tensor(y_np, stop_gradient=False)
            z = paddle.linalg.cholesky_solve(x, y, upper=self.upper)
            loss = paddle.sum(z)
            loss.backward()

            # Forward shape matches expectation; gradients mirror the inputs.
            self.assertEqual(z.shape, self.expected_shape)
            self.assertEqual(x.shape, x.grad.shape)
            self.assertEqual(y.shape, y.grad.shape)

        # `enumerate` index was unused in the original loop — iterate directly.
        for place in self.place:
            run(place)


class TestCholeskySolveAPIZeroSize1(TestCholeskySolveAPIZeroSize):
    def init_shape(self):
        """Zero-row case: 0x0 Cholesky factor solving a 0x6 right-hand side."""
        self.y_shape = [0, 0]
        self.x_shape = [0, 6]
        self.expected_shape = [0, 6]


class TestCholeskySolveAPIZeroSize2(TestCholeskySolveAPIZeroSize):
    def init_shape(self):
        """Batched case: empty batch dimension on y broadcast against x."""
        self.y_shape = [0, 10, 10]
        self.x_shape = [1, 10, 6]
        self.expected_shape = [0, 10, 6]


class TestCholeskySolveAPIZeroSize3(TestCholeskySolveAPIZeroSize):
    def init_shape(self):
        """Fully empty case: every dimension of every operand is zero."""
        shape = [0, 0, 0]
        self.x_shape = list(shape)
        self.y_shape = list(shape)
        self.expected_shape = list(shape)


# Allow running this test module directly from the command line.
if __name__ == "__main__":
    unittest.main()
Loading