@@ -97,6 +97,7 @@ def init_dtype(self):
9797class TestMultiplexOpError (unittest .TestCase ):
9898
9999 def test_errors (self ):
100+ paddle .enable_static ()
100101 with base .program_guard (base .Program (), base .Program ()):
101102 x1 = paddle .static .data (name = 'x1' , shape = [None , 2 ], dtype = 'int64' )
102103 x2 = paddle .static .data (name = 'x2' , shape = [None , 2 ], dtype = 'int64' )
@@ -198,5 +199,40 @@ def init_dtype(self):
198199 self .dtype = np .complex128
199200
200201
202+ class TestMultiplexOp_ZeroSize (OpTest ):
203+ def setUp (self ):
204+ self .op_type = "multiplex"
205+ self .init_dtype ()
206+ self .python_api = paddle .tensor .multiplex
207+ rows = 4
208+ index = np .array ([0 , 2 , 2 , 3 ]).astype ('int32' )
209+ np .random .shuffle (index )
210+ index = np .reshape (index , (rows , 1 ))
211+ ins1 = np .random .random ((rows , 0 )).astype (self .dtype )
212+ ins2 = np .random .random ((rows , 0 )).astype (self .dtype )
213+ ins3 = np .random .random ((rows , 0 )).astype (self .dtype )
214+ ins4 = np .random .random ((rows , 0 )).astype (self .dtype )
215+ self .inputs = {
216+ 'Ids' : index ,
217+ 'X' : [('x1' , ins1 ), ('x2' , ins2 ), ('x3' , ins3 ), ('x4' , ins4 )],
218+ }
219+ # multiplex output
220+ output = np .zeros_like (ins1 )
221+ for i in range (0 , rows ):
222+ k = index [i ][0 ]
223+ if self .inputs ['X' ][k ][1 ][i ].size != 0 :
224+ output [i ] = self .inputs ['X' ][k ][1 ][i ]
225+ self .outputs = {'Out' : output }
226+
227+ def init_dtype (self ):
228+ self .dtype = 'float64'
229+
230+ def test_check_output (self ):
231+ self .check_output (check_pir = True )
232+
233+ def test_check_grad (self ):
234+ self .check_grad (['x1' , 'x2' , 'x3' , 'x4' ], 'Out' , check_pir = True )
235+
236+
201237if __name__ == '__main__' :
202238 unittest .main ()
0 commit comments