@@ -684,12 +684,18 @@ def test_value_list_shape2(self):
684684 x = paddle .expand (x , shape = [shape1 , 1 , - 1 , - 1 ])
685685 np .testing .assert_equal (tuple (x .shape ), (- 1 , 1 , - 1 , - 1 ))
686686
687+
687688class TestExpandV2OneDNNOp (OpTest ):
688689 def setUp (self ):
689690 self .op_type = "expand_v2"
690691 self .init_data ()
691- self .x = np .random .random (self .ori_shape ).astype ("float32" )
692- self .attrs = {'shape' : self .shape , 'use_mkldnn' : True }
692+ self .python_api = paddle .expand
693+ self .x = np .zeros (self .ori_shape ).astype ("float32" )
694+ self .attrs = {
695+ 'shape' : self .shape ,
696+ 'use_mkldnn' : True ,
697+ 'dtype' : int (paddle .float32 ),
698+ }
693699 self .set_inputs ()
694700 self .set_additional_inputs ()
695701 output = np .zeros (self .expect_shape ).astype ("float32" )
@@ -702,30 +708,90 @@ def set_additional_inputs(self):
702708 pass
703709
704710 def init_data (self ):
705- self .ori_shape = [1 , 1 , 1 , 140 ]
706- self .shape = [2 , 3 , 0 , 140 ]
707- self .expect_shape = [2 , 3 , 0 , 140 ]
711+ self .ori_shape = [1 , 0 , 1 , 140 ]
712+ self .shape = [1 , 0 , 1 , 140 ]
713+ self .expect_shape = [1 , 0 , 1 , 140 ]
708714
709715 def test_check_output (self ):
710- self .check_output_with_place (core .CPUPlace (), check_pir_onednn = True ,check_dygraph = False )
711-
712- # def test_check_grad(self):
713- # self.check_grad_with_place(
714- # core.CPUPlace(), ["X"], "Out", check_pir_onednn=True, check_dygraph=False
715- # )
716+ self .check_output_with_place (
717+ core .CPUPlace (), check_pir_onednn = True , check_dygraph = False
718+ )
719+
720+ def test_check_grad (self ):
721+ self .check_grad_with_place (
722+ core .CPUPlace (),
723+ ["X" ],
724+ "Out" ,
725+ check_pir_onednn = True ,
726+ check_dygraph = False ,
727+ )
728+
729+
716730class TestExpandV2ZeroSizeOneDNNOp (TestExpandV2OneDNNOp ):
717731
718732 def init_data (self ):
719- self .ori_shape = (1 , 3 )
720- self .shape = (0 , 3 )
721- self .expect_shape = (0 , 3 )
733+ self .ori_shape = (0 , 130 )
734+ self .shape = (4 , 0 , 130 )
735+ self .expect_shape = (4 , 0 , 130 )
736+
722737
723738class TestExpandV2ZeroSizeOneDNNOp2 (TestExpandV2OneDNNOp ):
724739
725740 def init_data (self ):
726- self .ori_shape = (1 , 3 )
727- self .shape = (1 , 0 , 3 )
728- self .expect_shape = (1 , 0 , 3 )
741+ self .ori_shape = (0 , 1 , 8 )
742+ self .shape = (0 , 8 , 8 )
743+ self .expect_shape = (0 , 8 , 8 )
744+
745+
746+ class TestExpandV2GPUOp (TestExpandV2OneDNNOp ):
747+ def test_check_output (self ):
748+ self .check_output_with_place (core .CUDAPlace (0 ), check_dygraph = True )
749+
750+ def test_check_grad (self ):
751+ if core .is_compiled_with_cuda ():
752+ self .check_grad_with_place (
753+ core .CUDAPlace (0 ), ["X" ], "Out" , check_dygraph = True
754+ )
755+
756+
757+ class TestExpandV2ZeroSizeGPUOp (TestExpandV2GPUOp ):
758+ def init_data (self ):
759+ self .ori_shape = (0 , 130 )
760+ self .shape = (4 , 0 , 130 )
761+ self .expect_shape = (4 , 0 , 130 )
762+
763+
764+ class TestExpandV2ZeroSizeGPUOp2 (TestExpandV2GPUOp ):
765+ def init_data (self ):
766+ self .ori_shape = (0 , 1 )
767+ self .shape = (0 , 8 )
768+ self .expect_shape = (0 , 8 )
769+
770+
771+ class TestExpandV2CPUOp (TestExpandV2OneDNNOp ):
772+ def test_check_output (self ):
773+ self .check_output_with_place (core .CPUPlace (), check_dygraph = True )
774+
775+ def test_check_grad (self ):
776+ if core .is_compiled_with_cuda ():
777+ self .check_grad_with_place (
778+ core .CPUPlace (), ["X" ], "Out" , check_dygraph = True
779+ )
780+
781+
782+ class TestExpandV2CPUOp1 (TestExpandV2CPUOp ):
783+ def init_data (self ):
784+ self .ori_shape = (0 , 1 )
785+ self .shape = (0 , 8 )
786+ self .expect_shape = (0 , 8 )
787+
788+
789+ class TestExpandV2CPUOp2 (TestExpandV2CPUOp ):
790+ def init_data (self ):
791+ self .ori_shape = (0 , 130 )
792+ self .shape = (4 , 0 , 130 )
793+ self .expect_shape = (4 , 0 , 130 )
794+
729795
730796if __name__ == "__main__" :
731797 paddle .enable_static ()
0 commit comments