diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index bd9d3c571985..ee352fffbf27 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -1596,7 +1596,7 @@ def argsort(a, axis=-1, kind=None, order=None): if order is not None: raise NotImplementedError("order not supported here") - return _npi.argsort(data=a, axis=axis, is_ascend=True, dtype='int64') + return _api_internal.argsort(a, axis, True, 'int64') @set_module('mxnet.ndarray.numpy') diff --git a/src/api/operator/numpy/np_ordering_op.cc b/src/api/operator/numpy/np_ordering_op.cc index 5ca5380a1d06..bf9e8033448a 100644 --- a/src/api/operator/numpy/np_ordering_op.cc +++ b/src/api/operator/numpy/np_ordering_op.cc @@ -44,14 +44,45 @@ MXNET_REGISTER_API("_npi.sort") attrs.parsed = std::move(param); attrs.op = op; - // input + int num_inputs = 1; NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - //output + int num_outputs = 0; SetAttrDict(&attrs); auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); *ret = reinterpret_cast(ndoutputs[0]); }); +MXNET_REGISTER_API("_npi.argsort") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_argsort"); + nnvm::NodeAttrs attrs; + op::ArgSortParam param; + + if (args[1].type_code() == kNull) { + param.axis = dmlc::nullopt; + } else { + param.axis = args[1].operator int(); + } + param.is_ascend = args[2].operator bool(); + if (args[3].type_code() == kNull) { + param.dtype = mshadow::kFloat32; + } else { + param.dtype = String2MXNetTypeWithBool(args[1].operator std::string()); + } + + attrs.parsed = std::move(param); + attrs.op = op; + + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + + int num_outputs = 0; + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); +}); + } // namespace mxnet diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index 05dcdb2a3e99..f7226ac07d9f 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -30,6 +30,7 @@ #include #include #include +#include #include #include "../mshadow_op.h" #include "../elemwise_op_common.h" @@ -137,6 +138,15 @@ struct ArgSortParam : public dmlc::Parameter { " \"both\". An error will be raised if the selected data type cannot precisely " "represent the indices."); } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream axis_s, is_ascend_s, dtype_s; + axis_s << axis; + is_ascend_s << is_ascend; + dtype_s << dtype; + (*dict)["axis"] = axis_s.str(); + (*dict)["is_ascend_s"] = is_ascend_s.str(); + (*dict)["dtype"] = String2MXNetTypeWithBool(dtype); + } }; inline void ParseTopKParam(const TShape& src_shape,