@@ -216,6 +216,69 @@ def config(self):
216216 self .n = 32
217217
218218
219+ class TestMatrixPowerOpZeroSize (TestMatrixPowerOp ):
220+ def config (self ):
221+ self .matrix_shape = [0 , 0 ]
222+ self .dtype = "float32"
223+ self .n = 32
224+
225+
226+ class TestMatrixPowerOpZeroSize1 (TestMatrixPowerOp ):
227+ def config (self ):
228+ self .matrix_shape = [0 , 0 ]
229+ self .dtype = "float32"
230+ self .n = 0
231+
232+
233+ class TestMatrixPowerOpZeroSize2 (TestMatrixPowerOp ):
234+ def config (self ):
235+ self .matrix_shape = [0 , 0 ]
236+ self .dtype = "float32"
237+ self .n = - 1
238+
239+
240+ class TestMatrixPowerOpBatchedZeroSize1 (TestMatrixPowerOp ):
241+ def config (self ):
242+ self .matrix_shape = [2 , 0 , 4 , 4 ]
243+ self .dtype = "float32"
244+ self .n = 4
245+
246+
247+ class TestMatrixPowerOpBatchedZeroSize2 (TestMatrixPowerOp ):
248+ def config (self ):
249+ self .matrix_shape = [2 , 0 , 4 , 4 ]
250+ self .dtype = "float32"
251+ self .n = 0
252+
253+
254+ class TestMatrixPowerOpBatchedZeroSize3 (TestMatrixPowerOp ):
255+ def config (self ):
256+ self .matrix_shape = [2 , 0 , 4 , 4 ]
257+ self .dtype = "float32"
258+ self .n = - 1
259+
260+
261+ class TestMatrixPowerOpBatchedZeroSize4 (TestMatrixPowerOp ):
262+ def config (self ):
263+ self .matrix_shape = [2 , 6 , 0 , 0 ]
264+ self .dtype = "float32"
265+ self .n = 1
266+
267+
268+ class TestMatrixPowerOpBatchedZeroSize5 (TestMatrixPowerOp ):
269+ def config (self ):
270+ self .matrix_shape = [2 , 6 , 0 , 0 ]
271+ self .dtype = "float32"
272+ self .n = 0
273+
274+
275+ class TestMatrixPowerOpBatchedZeroSize6 (TestMatrixPowerOp ):
276+ def config (self ):
277+ self .matrix_shape = [2 , 6 , 0 , 0 ]
278+ self .dtype = "float32"
279+ self .n = - 1
280+
281+
219282@unittest .skipIf (
220283 core .is_compiled_with_xpu (),
221284 "Skip complex due to lack of mean support" ,
@@ -563,7 +626,7 @@ def _test_matrix_power_empty_static(self, place):
563626 self .assertEqual (res [0 ].shape , (0 , 0 ))
564627 self .assertEqual (res [1 ].shape , (2 , 3 , 0 , 0 ))
565628
566- def _test_matrix_power_empty_dynamtic (self ):
629+ def _test_matrix_power_empty_dynamic (self ):
567630 with dygraph_guard ():
568631 x2 = paddle .full ((0 , 6 ), 1.0 , dtype = 'float32' )
569632 x3 = paddle .full ((6 , 0 ), 1.0 , dtype = 'float32' )
@@ -581,7 +644,7 @@ def _test_matrix_power_empty_dynamtic(self):
581644 def test_matrix_power_empty_tensor (self ):
582645 for place in self ._get_places ():
583646 self ._test_matrix_power_empty_static (place )
584- self ._test_matrix_power_empty_dynamtic ()
647+ self ._test_matrix_power_empty_dynamic ()
585648
586649
587650if __name__ == "__main__" :
0 commit comments