From b5a974a84eda76b087edcecf6c84f8a4bca577f7 Mon Sep 17 00:00:00 2001
From: Hao Jin
Date: Thu, 30 Jan 2020 01:22:05 +0000
Subject: [PATCH] fix np.argmax/argmin output data type

---
 src/operator/numpy/np_broadcast_reduce_op.h         | 48 +++++++++++++++++++
 .../numpy/np_broadcast_reduce_op_index.cc           | 18 +++++--
 .../numpy/np_broadcast_reduce_op_index.cu           |  4 +-
 tests/python/unittest/test_numpy_op.py              |  2 +
 4 files changed, 66 insertions(+), 6 deletions(-)

diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h
index 38ee5a46e8a2..1099d94735d0 100644
--- a/src/operator/numpy/np_broadcast_reduce_op.h
+++ b/src/operator/numpy/np_broadcast_reduce_op.h
@@ -357,6 +357,54 @@ void NumpyReduceAxesBoolCompute(const nnvm::NodeAttrs& attrs,
   ReduceAxesComputeBoolImpl<xpu, reducer>(ctx, inputs, req, outputs, small);
 }
 
+template<typename xpu, typename reducer>
+void NumpySearchAxisCompute(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<TBlob>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  const ReduceAxisParam& param = nnvm::get<ReduceAxisParam>(attrs.parsed);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  int axis = inputs[0].ndim();
+  TBlob input = inputs[0];
+  if (param.axis.has_value()) {
+    axis = param.axis.value();
+  } else {
+    // If global reduction, reshape the input tensor into 2D shape (1, inputs[0].shape_.Size())
+    // and search on axis = 1.
+    mxnet::TShape shape_2d(2, 1);
+    shape_2d[1] = input.shape_.Size();
+    input = TBlob(input.dptr_, shape_2d, input.dev_mask(), input.type_flag_, input.dev_id());
+    axis = 1;
+  }
+
+  axis = CheckAxis(axis, input.shape_.ndim());
+  if (inputs[0].shape_.ndim() != 0) {
+    if (param.axis.has_value()) {
+      // cannot do argmax in an empty dimension
+      CHECK_NE(inputs[0].shape_[axis], 0)
+        << "searching input tensor of shape " << inputs[0].shape_
+        << " along axis = " << axis << " of zero dim-size is not allowed";
+    } else {
+      // cannot do argmax on an empty array
+      CHECK_NE(inputs[0].shape_.Size(), 0U) << "attempt to search an empty sequence";
+    }
+  }
+
+  if (input.shape_.Size() == 0U) return;  // zero-size tensor
+  mxnet::TShape shape = AxisShapeCompact(input.shape_, &axis, false);
+  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 2, int64_t> out =
+      outputs[0].get_with_shape<xpu, 2, int64_t>(Shape2(shape[0], shape[2]), s);
+    Tensor<xpu, 3, DType> in =
+      input.get_with_shape<xpu, 3, DType>(shape.get<3>(), s);
+    CHECK(req[0] != kAddTo) << "AddTo is not supported";
+    ASSIGN_DISPATCH(out, req[0], tcast<int64_t>(reduce_with_axis<reducer, true>(in, 1)));
+  });
+}
+
 template<typename xpu, bool normalize = false>
 inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
                                            const OpContext& ctx,
diff --git a/src/operator/numpy/np_broadcast_reduce_op_index.cc b/src/operator/numpy/np_broadcast_reduce_op_index.cc
index 15831c7e79ba..16c4fec8743b 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_index.cc
+++ b/src/operator/numpy/np_broadcast_reduce_op_index.cc
@@ -46,14 +46,24 @@ bool NumpyReduceAxisShape(const nnvm::NodeAttrs& attrs,
   return shape_is_known(out_attrs->at(0));
 }
 
+bool ArgMinMaxType(const nnvm::NodeAttrs& attrs,
+                   std::vector<int> *in_attrs,
+                   std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  CHECK_NE(in_attrs->at(0), -1);
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
+  return out_attrs->at(0) != -1;
+}
+
 NNVM_REGISTER_OP(_npi_argmax)
 .set_num_inputs(1)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<ReduceAxisParam>)
 .set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxisShape)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ArgMinMaxType)
 .add_argument("data", "NDArray-or-Symbol", "The input")
-.set_attr<FCompute>("FCompute", SearchAxisCompute<cpu, mshadow::red::maximum>)
+.set_attr<FCompute>("FCompute", NumpySearchAxisCompute<cpu, mshadow::red::maximum>)
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_arguments(ReduceAxisParam::__FIELDS__());
 
@@ -62,9 +72,9 @@ NNVM_REGISTER_OP(_npi_argmin)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<ReduceAxisParam>)
 .set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxisShape)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ArgMinMaxType)
 .add_argument("data", "NDArray-or-Symbol", "The input")
-.set_attr<FCompute>("FCompute", SearchAxisCompute<cpu, mshadow::red::minimum>)
+.set_attr<FCompute>("FCompute", NumpySearchAxisCompute<cpu, mshadow::red::minimum>)
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_arguments(ReduceAxisParam::__FIELDS__());
 
diff --git a/src/operator/numpy/np_broadcast_reduce_op_index.cu b/src/operator/numpy/np_broadcast_reduce_op_index.cu
index 0420133ee7c0..206f09b8c185 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_index.cu
+++ b/src/operator/numpy/np_broadcast_reduce_op_index.cu
@@ -28,10 +28,10 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(_npi_argmax)
-.set_attr<FCompute>("FCompute", SearchAxisCompute<gpu, mshadow::red::maximum>);
+.set_attr<FCompute>("FCompute", NumpySearchAxisCompute<gpu, mshadow::red::maximum>);
 
 NNVM_REGISTER_OP(_npi_argmin)
-.set_attr<FCompute>("FCompute", SearchAxisCompute<gpu, mshadow::red::minimum>);
+.set_attr<FCompute>("FCompute", NumpySearchAxisCompute<gpu, mshadow::red::minimum>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index f8d6817ea444..c7669a1a7140 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3127,6 +3127,7 @@ def hybrid_forward(self, F, x):
         else:
             mx_ret = getattr(np, op_name)(a, axis=axis)
             np_ret = getattr(_np, op_name)(a.asnumpy(), axis=axis)
+            assert mx_ret.dtype == np_ret.dtype
             assert same(mx_ret.asnumpy(), np_ret)
 
     for hybridize in [False, True]:
@@ -3142,6 +3143,7 @@ def hybrid_forward(self, F, x):
                 pass
             else:
                 mx_ret = net(a)
+                assert mx_ret.dtype == np_ret.dtype
                 assert same(mx_ret.asnumpy(), np_ret)