@@ -31,14 +31,19 @@ def setUp(self):
         self.op_type = "multi_dot"
         self.python_api = paddle.linalg.multi_dot
         self.dtype = self.get_dtype()
+        self.init_shape()
         self.get_inputs_and_outputs()

+    def init_shape(self):
+        self.A_shape = (2, 8)
+        self.B_shape = (8, 4)
+
     def get_dtype(self):
         return "float64"

     def get_inputs_and_outputs(self):
-        self.A = np.random.random((2, 8)).astype(self.dtype)
-        self.B = np.random.random((8, 4)).astype(self.dtype)
+        self.A = np.random.random(self.A_shape).astype(self.dtype)
+        self.B = np.random.random(self.B_shape).astype(self.dtype)
         self.inputs = {'X': [('x0', self.A), ('x1', self.B)]}
         self.outputs = {'Out': multi_dot([self.A, self.B])}

@@ -55,6 +60,36 @@ def get_dtype(self):
         return "float16"


+class TestMultiDotOp_ZeroSize1(TestMultiDotOp):
+    def get_inputs_and_outputs(self):
+        # result shape: [2, 3]
+        self.A = np.random.random((2, 10)).astype(self.dtype)
+        self.B = np.random.random((10, 0)).astype(self.dtype)
+        self.C = np.random.random((0, 3)).astype(self.dtype)
+        self.inputs = {'X': [('x0', self.A), ('x1', self.B), ('x2', self.C)]}
+        self.outputs = {'Out': multi_dot([self.A, self.B, self.C])}
+
+    def test_check_grad(self):
+        self.check_grad(['x0'], 'Out', check_pir=True)
+        self.check_grad(['x1'], 'Out', check_pir=True)
+        self.check_grad(['x2'], 'Out', check_pir=True)
+
+
+class TestMultiDotOp_ZeroSize2(TestMultiDotOp):
+    def get_inputs_and_outputs(self):
+        # result shape: [0, 3]
+        self.A = np.random.random((0, 10)).astype(self.dtype)
+        self.B = np.random.random((10, 4)).astype(self.dtype)
+        self.C = np.random.random((4, 3)).astype(self.dtype)
+        self.inputs = {'X': [('x0', self.A), ('x1', self.B), ('x2', self.C)]}
+        self.outputs = {'Out': multi_dot([self.A, self.B, self.C])}
+
+    def test_check_grad(self):
+        self.check_grad(['x0'], 'Out', check_pir=True)
+        self.check_grad(['x1'], 'Out', check_pir=True)
+        self.check_grad(['x2'], 'Out', check_pir=True)
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
     or not core.is_bfloat16_supported(core.CUDAPlace(0)),
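
Reviewer sketch, not part of the patch: the two ZeroSize cases above exercise chained matmuls in which one dimension is 0. The snippet below illustrates the expected shapes with numpy.linalg.multi_dot, which paddle.linalg.multi_dot is assumed to mirror for these zero-size operands.

# Illustrative only; NumPy is used here on the assumption that Paddle's
# multi_dot follows the same shape rules for zero-size inputs.
import numpy as np

# ZeroSize1: (2, 10) x (10, 0) x (0, 3) -> (2, 3); every entry is an empty
# sum over the size-0 inner dimension, so the result is exactly zero.
out1 = np.linalg.multi_dot(
    [np.random.random((2, 10)),
     np.random.random((10, 0)),
     np.random.random((0, 3))]
)
assert out1.shape == (2, 3) and np.all(out1 == 0.0)

# ZeroSize2: (0, 10) x (10, 4) x (4, 3) -> (0, 3); the leading dimension is
# zero, so the result is an empty array.
out2 = np.linalg.multi_dot(
    [np.random.random((0, 10)),
     np.random.random((10, 4)),
     np.random.random((4, 3))]
)
assert out2.shape == (0, 3)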