@@ -923,5 +923,87 @@ def test_dygraph(self):
923923 print ("The mat is singular" )
924924
925925
926+ class TestSolveOpAPIZeroDimCase (unittest .TestCase ):
927+ def setUp (self ):
928+ np .random .seed (2021 )
929+ self .place = []
930+ self .dtype = "float32"
931+ if (
932+ os .environ .get ('FLAGS_CI_both_cpu_and_gpu' , 'False' ).lower ()
933+ in ['1' , 'true' , 'on' ]
934+ or not core .is_compiled_with_cuda ()
935+ ):
936+ self .place .append (paddle .CPUPlace ())
937+ if core .is_compiled_with_cuda ():
938+ self .place .append (paddle .CUDAPlace (0 ))
939+
940+ def check_static_result (self , place , x_shape , y_shape , np_y_shape ):
941+ paddle .enable_static ()
942+ with base .program_guard (base .Program (), base .Program ()):
943+ paddle_input_x = paddle .static .data (
944+ name = "input_x" , shape = x_shape , dtype = self .dtype
945+ )
946+ paddle_input_y = paddle .static .data (
947+ name = "input_y" , shape = y_shape , dtype = self .dtype
948+ )
949+ paddle_result = paddle .linalg .solve (
950+ paddle_input_x , paddle_input_y , left = False
951+ )
952+
953+ np_input_x = np .random .random (x_shape ).astype (self .dtype )
954+ np_input_y = np .random .random (np_y_shape ).astype (self .dtype )
955+
956+ np_result = np .linalg .solve (np_input_x , np_input_y )
957+
958+ exe = base .Executor (place )
959+ fetches = exe .run (
960+ base .default_main_program (),
961+ feed = {"input_x" : np_input_x , "input_y" : np_input_y },
962+ fetch_list = [paddle_result ],
963+ )
964+ np .testing .assert_allclose (fetches [0 ], np_result , rtol = 0.0001 )
965+
966+ def test_static (self ):
967+ for place in self .place :
968+ self .check_static_result (
969+ place = place ,
970+ x_shape = [10 , 0 , 0 ],
971+ y_shape = [6 , 0 , 0 ],
972+ np_y_shape = [10 , 0 , 0 ],
973+ )
974+ with self .assertRaises (ValueError ) as context :
975+ self .check_static_result (
976+ place = place ,
977+ x_shape = [10 , 0 , 0 ],
978+ y_shape = [10 ],
979+ np_y_shape = [10 ],
980+ )
981+
982+ def test_dygraph (self ):
983+ def run (place , x_shape , y_shape ):
984+ with base .dygraph .guard (place ):
985+ input_x_np = np .random .random (x_shape ).astype (self .dtype )
986+ input_y_np = np .random .random (y_shape ).astype (self .dtype )
987+
988+ tensor_input_x = paddle .to_tensor (input_x_np )
989+ tensor_input_y = paddle .to_tensor (input_y_np )
990+
991+ numpy_output = np .linalg .solve (input_x_np , input_y_np )
992+ paddle_output = paddle .linalg .solve (
993+ tensor_input_x , tensor_input_y , left = False
994+ )
995+ np .testing .assert_allclose (
996+ numpy_output , paddle_output .numpy (), rtol = 0.0001
997+ )
998+ self .assertEqual (
999+ numpy_output .shape , paddle_output .numpy ().shape
1000+ )
1001+
1002+ for place in self .place :
1003+ run (place , x_shape = [10 , 0 , 0 ], y_shape = [10 , 0 , 0 ])
1004+ with self .assertRaises (ValueError ) as context :
1005+ run (place , x_shape = [10 , 0 , 0 ], y_shape = [10 ])
1006+
1007+
9261008if __name__ == "__main__" :
9271009 unittest .main ()
0 commit comments