Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ void MatrixPowerGradKernel(const Context& ctx,
auto Out = &out;
auto dOut = &out_grad;
auto dX = x_grad;
if (x_grad && x_grad->numel() == 0) {
ctx.template Alloc<T>(x_grad);
return;
}

MatrixPowerGradFunction<Context, T>(X, Out, dOut, n, dX, ctx);
}
Expand Down
67 changes: 65 additions & 2 deletions test/legacy_test/test_matrix_power_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,69 @@ def config(self):
self.n = 32


class TestMatrixPowerOpZeroSize(TestMatrixPowerOp):
def config(self):
self.matrix_shape = [0, 0]
self.dtype = "float32"
self.n = 32


class TestMatrixPowerOpZeroSize1(TestMatrixPowerOp):
def config(self):
self.matrix_shape = [0, 0]
self.dtype = "float32"
self.n = 0


class TestMatrixPowerOpZeroSize2(TestMatrixPowerOp):
def config(self):
self.matrix_shape = [0, 0]
self.dtype = "float32"
self.n = -1


class TestMatrixPowerOpBatchedZeroSize1(TestMatrixPowerOp):
def config(self):
self.matrix_shape = [2, 0, 4, 4]
self.dtype = "float32"
self.n = 4


class TestMatrixPowerOpBatchedZeroSize2(TestMatrixPowerOp):
def config(self):
self.matrix_shape = [2, 0, 4, 4]
self.dtype = "float32"
self.n = 0


class TestMatrixPowerOpBatchedZeroSize3(TestMatrixPowerOp):
def config(self):
self.matrix_shape = [2, 0, 4, 4]
self.dtype = "float32"
self.n = -1


class TestMatrixPowerOpBatchedZeroSize4(TestMatrixPowerOp):
def config(self):
self.matrix_shape = [2, 6, 0, 0]
self.dtype = "float32"
self.n = 1


class TestMatrixPowerOpBatchedZeroSize5(TestMatrixPowerOp):
def config(self):
self.matrix_shape = [2, 6, 0, 0]
self.dtype = "float32"
self.n = 0


class TestMatrixPowerOpBatchedZeroSize6(TestMatrixPowerOp):
def config(self):
self.matrix_shape = [2, 6, 0, 0]
self.dtype = "float32"
self.n = -1


@unittest.skipIf(
core.is_compiled_with_xpu(),
"Skip complex due to lack of mean support",
Expand Down Expand Up @@ -563,7 +626,7 @@ def _test_matrix_power_empty_static(self, place):
self.assertEqual(res[0].shape, (0, 0))
self.assertEqual(res[1].shape, (2, 3, 0, 0))

def _test_matrix_power_empty_dynamtic(self):
def _test_matrix_power_empty_dynamic(self):
with dygraph_guard():
x2 = paddle.full((0, 6), 1.0, dtype='float32')
x3 = paddle.full((6, 0), 1.0, dtype='float32')
Expand All @@ -581,7 +644,7 @@ def _test_matrix_power_empty_dynamtic(self):
def test_matrix_power_empty_tensor(self):
for place in self._get_places():
self._test_matrix_power_empty_static(place)
self._test_matrix_power_empty_dynamtic()
self._test_matrix_power_empty_dynamic()


if __name__ == "__main__":
Expand Down
Loading