Commit e7ae67a

[0-size Tensor No.110] Add 0-size Tensor support for paddle.linalg.matrix_power API. (#72790)
* fix MatrixPowerGradKernel, add unittest: MatrixPowerGradKernel supports 0-size Tensor
* Update matrix_power_grad_kernel_impl.h: fix error
* Update matrix_power_grad_kernel_impl.h
1 parent: 988669d

File tree (2 files changed: +69 -2 lines)

  paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h
  test/legacy_test/test_matrix_power_op.py

paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h

Lines changed: 4 additions & 0 deletions

@@ -193,6 +193,10 @@ void MatrixPowerGradKernel(const Context& ctx,
   auto Out = &out;
   auto dOut = &out_grad;
   auto dX = x_grad;
+  if (x_grad && x_grad->numel() == 0) {
+    ctx.template Alloc<T>(x_grad);
+    return;
+  }

   MatrixPowerGradFunction<Context, T>(X, Out, dOut, n, dX, ctx);
 }

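For context, a minimal dygraph sketch of the behavior this early return is meant to enable (illustrative only, not part of the commit; it assumes eager mode and that sum() over an empty tensor yields a scalar):

import paddle

# Hypothetical usage sketch: with the 0-size early return in the gradient
# kernel, backward on an empty input just allocates an empty gradient.
x = paddle.zeros([0, 0], dtype='float32')
x.stop_gradient = False
out = paddle.linalg.matrix_power(x, 32)  # matches the n=32 zero-size test case
out.sum().backward()                     # scalar loss over an empty tensor
print(out.shape)     # [0, 0]
print(x.grad.shape)  # [0, 0]
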
test/legacy_test/test_matrix_power_op.py

Lines changed: 65 additions & 2 deletions

@@ -216,6 +216,69 @@ def config(self):
         self.n = 32


+class TestMatrixPowerOpZeroSize(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [0, 0]
+        self.dtype = "float32"
+        self.n = 32
+
+
+class TestMatrixPowerOpZeroSize1(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [0, 0]
+        self.dtype = "float32"
+        self.n = 0
+
+
+class TestMatrixPowerOpZeroSize2(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [0, 0]
+        self.dtype = "float32"
+        self.n = -1
+
+
+class TestMatrixPowerOpBatchedZeroSize1(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [2, 0, 4, 4]
+        self.dtype = "float32"
+        self.n = 4
+
+
+class TestMatrixPowerOpBatchedZeroSize2(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [2, 0, 4, 4]
+        self.dtype = "float32"
+        self.n = 0
+
+
+class TestMatrixPowerOpBatchedZeroSize3(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [2, 0, 4, 4]
+        self.dtype = "float32"
+        self.n = -1
+
+
+class TestMatrixPowerOpBatchedZeroSize4(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [2, 6, 0, 0]
+        self.dtype = "float32"
+        self.n = 1
+
+
+class TestMatrixPowerOpBatchedZeroSize5(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [2, 6, 0, 0]
+        self.dtype = "float32"
+        self.n = 0
+
+
+class TestMatrixPowerOpBatchedZeroSize6(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [2, 6, 0, 0]
+        self.dtype = "float32"
+        self.n = -1
+
+
 @unittest.skipIf(
     core.is_compiled_with_xpu(),
     "Skip complex due to lack of mean support",

@@ -563,7 +626,7 @@ def _test_matrix_power_empty_static(self, place):
         self.assertEqual(res[0].shape, (0, 0))
         self.assertEqual(res[1].shape, (2, 3, 0, 0))

-    def _test_matrix_power_empty_dynamtic(self):
+    def _test_matrix_power_empty_dynamic(self):
         with dygraph_guard():
             x2 = paddle.full((0, 6), 1.0, dtype='float32')
             x3 = paddle.full((6, 0), 1.0, dtype='float32')

@@ -581,7 +644,7 @@ def _test_matrix_power_empty_dynamtic(self):
     def test_matrix_power_empty_tensor(self):
         for place in self._get_places():
             self._test_matrix_power_empty_static(place)
-            self._test_matrix_power_empty_dynamtic()
+            self._test_matrix_power_empty_dynamic()


 if __name__ == "__main__":

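As a quick illustration of what the batched zero-size cases above exercise, here is a minimal hand-run sketch (assumes eager mode; the authoritative coverage is the test classes themselves):

import paddle

# A batch with a zero-size dimension should keep its shape through
# matrix_power for positive, zero, and negative exponents.
x = paddle.full([2, 0, 4, 4], 1.0, dtype='float32')
for n in (4, 0, -1):
    out = paddle.linalg.matrix_power(x, n)
    print(n, out.shape)  # expected [2, 0, 4, 4] for each n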