@@ -513,6 +513,60 @@ def test_api_normal_3(self):
513513 paddle .enable_static ()
514514
515515
516+ class TestAddmmOp_ZeroSize (OpTest ):
517+ def setUp (self ):
518+ self .op_type = "addmm"
519+ self .python_api = paddle .addmm
520+ self .public_python_api = paddle .addmm
521+ self .init_dtype_type ()
522+ self .init_input ()
523+ self .attrs = {
524+ 'Alpha' : 0.5 ,
525+ 'Beta' : 2.0 ,
526+ }
527+ self .outputs = {
528+ 'Out' : self .attrs ['Beta' ] * self .inputs ['Input' ]
529+ + self .attrs ['Alpha' ] * np .dot (self .inputs ['X' ], self .inputs ['Y' ])
530+ }
531+
532+ def init_input (self ):
533+ # result shape: [20, 100]
534+ self .inputs = {
535+ 'Input' : np .random .random (100 ).astype (self .dtype ),
536+ 'X' : np .random .random ((20 , 0 )).astype (self .dtype ),
537+ 'Y' : np .random .random ((0 , 100 )).astype (self .dtype ),
538+ }
539+
540+ def init_dtype_type (self ):
541+ self .dtype = np .float64
542+
543+ def test_check_output (self ):
544+ self .check_output (check_pir = True )
545+
546+ def test_check_grad_normal (self ):
547+ self .check_grad (['Input' , 'X' , 'Y' ], 'Out' , check_pir = True )
548+
549+
550+ class TestAddmmOp_ZeroSize2 (TestAddmmOp_ZeroSize ):
551+ def init_input (self ):
552+ # result shape: [20, 0]
553+ self .inputs = {
554+ 'Input' : np .random .random (0 ).astype (self .dtype ),
555+ 'X' : np .random .random ((20 , 100 )).astype (self .dtype ),
556+ 'Y' : np .random .random ((100 , 0 )).astype (self .dtype ),
557+ }
558+
559+
560+ class TestAddmmOp_ZeroSize3 (TestAddmmOp_ZeroSize ):
561+ def init_input (self ):
562+ # result shape: [0, 0]
563+ self .inputs = {
564+ 'Input' : np .random .random (0 ).astype (self .dtype ),
565+ 'X' : np .random .random ((0 , 100 )).astype (self .dtype ),
566+ 'Y' : np .random .random ((100 , 0 )).astype (self .dtype ),
567+ }
568+
569+
516570if __name__ == "__main__" :
517571 paddle .enable_static ()
518572 unittest .main ()
0 commit comments