File tree Expand file tree Collapse file tree 2 files changed +33
-0
lines changed Expand file tree Collapse file tree 2 files changed +33
-0
lines changed Original file line number Diff line number Diff line change @@ -308,6 +308,12 @@ struct AMaxOrAMinGradFunctor {
308308 mask.sum (axis).reshape (dy->dimensions ()).broadcast (dim);
309309 return ;
310310 }
311+
312+ if (rank == 0 ) {
313+ dx->device (place) = dy->broadcast (dim) * mask;
314+ return ;
315+ }
316+
311317 // axis is list, HANDLE_AXIS_DIM(broadcast_dim_size, rank)
312318 HANDLE_AXIS_DIM (3 , 2 );
313319 HANDLE_AXIS_DIM (4 , 2 );
Original file line number Diff line number Diff line change @@ -139,6 +139,33 @@ def _test_dygraph(func):
139139 # test two minimum or maximum elements
140140
141141
142+ class TestMaxMinAmaxAminAPI_AxisWithOne1 (TestMaxMinAmaxAminAPI ):
143+ def init_case (self ):
144+ self .x_np = np .random .randn (1 , 5 , 10 ).astype (np .float32 )
145+ self .shape = [1 , 5 , 10 ]
146+ self .dtype = 'float32'
147+ self .axis = 0
148+ self .keepdim = False
149+
150+
151+ class TestMaxMinAmaxAminAPI_AxisWithOne2 (TestMaxMinAmaxAminAPI ):
152+ def init_case (self ):
153+ self .x_np = np .random .randn (1 , 5 , 10 ).astype (np .float32 )
154+ self .shape = [1 , 5 , 10 ]
155+ self .dtype = 'float32'
156+ self .axis = 0
157+ self .keepdim = True
158+
159+
160+ class TestMaxMinAmaxAminAPI_AxisWithOne3 (TestMaxMinAmaxAminAPI ):
161+ def init_case (self ):
162+ self .x_np = np .random .randn (1 , 1 , 10 ).astype (np .float32 )
163+ self .shape = [1 , 1 , 10 ]
164+ self .dtype = 'float32'
165+ self .axis = (0 , 1 )
166+ self .keepdim = False
167+
168+
142169class TestMaxMinAmaxAminAPI_ZeroDim (TestMaxMinAmaxAminAPI ):
143170 def init_case (self ):
144171 self .x_np = np .array (0.5 )
You can’t perform that action at this time.
0 commit comments