Skip to content

Commit

Permalink
fix arg_max to select first index (PaddlePaddle#44521)
Browse files Browse the repository at this point in the history
  • Loading branch information
cifar10 authored and Aurelius84 committed Jul 29, 2022
1 parent c9da12f commit 36b9ccb
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/arg_max_op_mlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class ArgMaxMLUKernel : public framework::OpKernel<T> {
MLUCnnlTensorDesc input_desc(
flatten_x, CNNL_LAYOUT_ARRAY, ToCnnlDataType(flatten_x.dtype()));
MLUCnnlReduceDesc reduction_desc(reduce_dims,
CNNL_REDUCE_MAX_LAST_INDEX,
CNNL_REDUCE_MAX,
ToCnnlDataType<T>(),
CNNL_NOT_PROPAGATE_NAN,
CNNL_REDUCE_ONLY_INDICES,
Expand Down
32 changes: 32 additions & 0 deletions python/paddle/fluid/tests/unittests/mlu/test_arg_max_op_mlu.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,38 @@ def test_check_output(self):
self.check_output_with_place(self.place)


class TestArgMaxSameValue1(BaseTestCase):

def initTestCase(self):
self.op_type = 'arg_max'
self.dtype = 'float32'
self.axis = 0

def setUp(self):
self.set_mlu()
self.initTestCase()
self.x = np.array([1, 2, 3, 5, 4, 5]).astype(self.dtype)
self.inputs = {'X': self.x}
self.attrs = {'axis': self.axis}
self.outputs = {'Out': np.argmax(self.x, axis=self.axis)}


class TestArgMaxSameValue2(BaseTestCase):

def initTestCase(self):
self.op_type = 'arg_max'
self.dtype = 'float16'
self.axis = 0

def setUp(self):
self.set_mlu()
self.initTestCase()
self.x = np.array([[2, 3, 5, 5], [3, 2, 5, 5]]).astype(self.dtype)
self.inputs = {'X': self.x}
self.attrs = {'axis': self.axis}
self.outputs = {'Out': np.argmax(self.x, axis=self.axis)}


# test argmax, dtype: float16
class TestArgMaxFloat16Case1(BaseTestCase):

Expand Down

0 comments on commit 36b9ccb

Please sign in to comment.