Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
* argsort FFI
Browse files Browse the repository at this point in the history
  • Loading branch information
hanke580 committed Mar 24, 2020
1 parent 68ce48a commit 2a504b2
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1596,7 +1596,7 @@ def argsort(a, axis=-1, kind=None, order=None):
if order is not None:
raise NotImplementedError("order not supported here")

return _npi.argsort(data=a, axis=axis, is_ascend=True, dtype='int64')
return _api_internal.argsort(a, axis, True, 'int64')


@set_module('mxnet.ndarray.numpy')
Expand Down
35 changes: 33 additions & 2 deletions src/api/operator/numpy/np_ordering_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,45 @@ MXNET_REGISTER_API("_npi.sort")

attrs.parsed = std::move(param);
attrs.op = op;
// input

int num_inputs = 1;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
//output

int num_outputs = 0;
SetAttrDict<op::SortParam>(&attrs);
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
});

MXNET_REGISTER_API("_npi.argsort")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_argsort");
nnvm::NodeAttrs attrs;
op::ArgSortParam param;

if (args[1].type_code() == kNull) {
param.axis = dmlc::nullopt;
} else {
param.axis = args[1].operator int();
}
param.is_ascend = args[2].operator bool();
if (args[3].type_code() == kNull) {
param.dtype = mshadow::kFloat32;
} else {
param.dtype = String2MXNetTypeWithBool(args[1].operator std::string());
}

attrs.parsed = std::move(param);
attrs.op = op;

int num_inputs = 1;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};

int num_outputs = 0;
SetAttrDict<op::ArgSortParam>(&attrs);
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
});

} // namespace mxnet
10 changes: 10 additions & 0 deletions src/operator/tensor/ordering_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <mshadow/tensor.h>
#include <algorithm>
#include <vector>
#include <string>
#include <type_traits>
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
Expand Down Expand Up @@ -137,6 +138,15 @@ struct ArgSortParam : public dmlc::Parameter<ArgSortParam> {
" \"both\". An error will be raised if the selected data type cannot precisely "
"represent the indices.");
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream axis_s, is_ascend_s, dtype_s;
axis_s << axis;
is_ascend_s << is_ascend;
dtype_s << dtype;
(*dict)["axis"] = axis_s.str();
(*dict)["is_ascend_s"] = is_ascend_s.str();
(*dict)["dtype"] = String2MXNetTypeWithBool(dtype);
}
};

inline void ParseTopKParam(const TShape& src_shape,
Expand Down

0 comments on commit 2a504b2

Please sign in to comment.