@@ -96,6 +96,65 @@ def test_check_grad(self):
9696 self .check_grad_with_place (self .place , {'X' }, 'Out' )
9797
9898
99+ class XPUTestArgsortOp_0D (XPUOpTestWrapper ):
100+ def __init__ (self ):
101+ self .op_name = 'argsort'
102+ self .use_dynamic_create_class = False
103+
104+ class TestArgsortOpCase1 (XPUOpTest ):
105+ def setUp (self ):
106+ self .set_xpu ()
107+ self .op_type = "argsort"
108+ self .place = paddle .XPUPlace (0 )
109+ self .dtype = self .in_type
110+ self .input_shape = 0
111+ self .axis = (
112+ - 1 if not hasattr (self , 'init_axis' ) else self .init_axis
113+ )
114+ self .descending = (
115+ False
116+ if not hasattr (self , 'init_descending' )
117+ else self .init_descending
118+ )
119+
120+ if self .dtype == np .float32 :
121+ self .x = np .random .random (self .input_shape ).astype (
122+ self .dtype
123+ )
124+ else :
125+ self .x = np .random .choice (
126+ low = - 1000 , high = 1000 , size = self .input_shape
127+ ).astype (self .dtype )
128+
129+ self .inputs = {"X" : self .x }
130+ self .attrs = {"axis" : self .axis , "descending" : self .descending }
131+ self .get_output ()
132+ self .outputs = {"Out" : self .sorted_x , "Indices" : self .indices }
133+
134+ def get_output (self ):
135+ if self .descending :
136+ self .indices = np .flip (
137+ np .argsort (self .x , kind = 'heapsort' , axis = self .axis ),
138+ self .axis ,
139+ )
140+ self .sorted_x = np .flip (
141+ np .sort (self .x , kind = 'heapsort' , axis = self .axis ), self .axis
142+ )
143+ else :
144+ self .indices = np .argsort (self .x , kind = 'heapsort' , axis = self .axis )
145+ self .sorted_x = np .sort (self .x , kind = 'heapsort' , axis = self .axis )
146+
147+ def set_xpu (self ):
148+ self .__class__ .use_xpu = True
149+ self .__class__ .no_need_check_grad = True
150+
151+ def test_check_output (self ):
152+ self .check_output_with_place (self .place )
153+
154+ def test_check_grad (self ):
155+ self .check_grad_with_place (self .place , {'X' }, 'Out' )
156+
157+
99158class XPUTestArgsortOp_LargeN (XPUOpTestWrapper ):
100159 def __init__ (self ):
101160 self .op_name = 'argsort'
0 commit comments