diff --git a/src/api/operator/numpy/np_bincount_op.cc b/src/api/operator/numpy/np_bincount_op.cc index e98c1a4a750d..afa3278c24e4 100644 --- a/src/api/operator/numpy/np_bincount_op.cc +++ b/src/api/operator/numpy/np_bincount_op.cc @@ -21,6 +21,7 @@ * \file np_bincount_op.cc * \brief Implementation of the API of functions in src/operator/numpy/np_bincount_op.cc */ +#include #include "../utils.h" #include "../../../operator/numpy/np_bincount_op-inl.h" @@ -38,17 +39,21 @@ MXNET_REGISTER_API("_npi.bincount") param.minlength = args[2].operator int64_t(); param.has_weights = false; NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - attrs.parsed = param; // remove std::move() ci error: trivially-copyable type + int num_inputs = 1; + attrs.parsed = param; attrs.op = op; - auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, nullptr); + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); *ret = reinterpret_cast(ndoutputs[0]); } else { param.minlength = args[2].operator int64_t(); param.has_weights = true; NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; - attrs.parsed = param; // remove std::move() ci error: trivially-copyable type + int num_inputs = 2; + attrs.parsed = param; attrs.op = op; - auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, nullptr); + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); *ret = reinterpret_cast(ndoutputs[0]); } }); diff --git a/src/api/operator/numpy/np_broadcast_reduce_op_boolean.cc b/src/api/operator/numpy/np_broadcast_reduce_op_boolean.cc index 32f70e8e2fc2..dea510a41608 100644 --- a/src/api/operator/numpy/np_broadcast_reduce_op_boolean.cc +++ b/src/api/operator/numpy/np_broadcast_reduce_op_boolean.cc @@ -21,6 +21,8 @@ * \file np_broadcast_reduce_op_boolean.cc * \brief Implementation of the API of functions in src/operator/numpy/np_broadcast_reduce_op_boolean.cc */ +#include +#include #include "../utils.h" #include "../../../operator/numpy/np_broadcast_reduce_op.h" @@ -45,11 +47,16 @@ MXNET_REGISTER_API("_npi.all") } param.keepdims = args[2].operator bool(); NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + int num_inputs = 1; attrs.parsed = std::move(param); attrs.op = op; - auto ndoutputs = Invoke( - op, &attrs, 1, inputs, &num_outputs, outputs); - *ret = reinterpret_cast(ndoutputs[0]); + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(3); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } }); MXNET_REGISTER_API("_npi.any") @@ -71,11 +78,16 @@ MXNET_REGISTER_API("_npi.any") } param.keepdims = args[2].operator bool(); NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + int num_inputs = 1; attrs.parsed = std::move(param); attrs.op = op; - auto ndoutputs = Invoke( - op, &attrs, 1, inputs, &num_outputs, outputs); - *ret = reinterpret_cast(ndoutputs[0]); + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(3); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_percentile_op.cc b/src/api/operator/numpy/np_percentile_op.cc index 038e2d39c3f4..634ee092c64d 100644 --- a/src/api/operator/numpy/np_percentile_op.cc +++ b/src/api/operator/numpy/np_percentile_op.cc @@ -21,12 +21,14 @@ * \file np_percentile_op.cc * \brief Implementation of the API of functions in src/operator/numpy/np_percentile_op.cc */ +#include +#include #include "../utils.h" #include "../../../operator/numpy/np_percentile_op-inl.h" namespace mxnet { -inline int String2MXnetPercentileType(const std::string& s) { +inline int String2MXNetPercentileType(const std::string& s) { using namespace op; if (s == "linear") { return percentile_enum::kLinear; @@ -62,22 +64,34 @@ MXNET_REGISTER_API("_npi.percentile") } else { param.axis = Tuple(args[2].operator ObjectRef()); } - param.interpolation = String2MXnetPercentileType(args[3].operator std::string()); + param.interpolation = String2MXNetPercentileType(args[3].operator std::string()); param.keepdims = args[4].operator bool(); if (args[1].type_code() == kDLInt || args[1].type_code() == kDLFloat) { param.q_scalar = args[1].operator double(); NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + int num_inputs = 1; attrs.parsed = std::move(param); attrs.op = op; - auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, outputs); - *ret = reinterpret_cast(ndoutputs[0]); + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(5); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } } else { param.q_scalar = dmlc::nullopt; NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; + int num_inputs = 2; attrs.parsed = std::move(param); attrs.op = op; - auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, outputs); - *ret = reinterpret_cast(ndoutputs[0]); + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(5); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } } }); diff --git a/src/api/operator/op_utils.cc b/src/api/operator/op_utils.cc index 220a880336db..bb54662e7a62 100644 --- a/src/api/operator/op_utils.cc +++ b/src/api/operator/op_utils.cc @@ -24,6 +24,7 @@ #include "op_utils.h" #include +#include "../../operator/numpy/np_percentile_op-inl.h" namespace mxnet { @@ -52,4 +53,24 @@ std::string String2MXNetTypeWithBool(int dtype) { return ""; } +std::string MXNetPercentileType2String(int interpolation) { + using namespace op; + switch (interpolation) { + case percentile_enum::kLinear: + return "linear"; + case percentile_enum::kLower: + return "lower"; + case percentile_enum::kHigher: + return "higher"; + case percentile_enum::kMidpoint: + return "midpoint"; + case percentile_enum::kNearest: + return "nearest"; + default: + LOG(FATAL) << "Unknown type enum " << interpolation; + } + LOG(FATAL) << "should not reach here "; + return ""; +} + } // namespace mxnet diff --git a/src/api/operator/op_utils.h b/src/api/operator/op_utils.h index 4c577983c405..f41680df6fd6 100644 --- a/src/api/operator/op_utils.h +++ b/src/api/operator/op_utils.h @@ -29,6 +29,7 @@ namespace mxnet { std::string String2MXNetTypeWithBool(int dtype); +std::string MXNetPercentileType2String(int interpolation); } // namespace mxnet diff --git a/src/operator/numpy/np_percentile_op-inl.h b/src/operator/numpy/np_percentile_op-inl.h index 130e2662532d..80d275f8872c 100644 --- a/src/operator/numpy/np_percentile_op-inl.h +++ b/src/operator/numpy/np_percentile_op-inl.h @@ -33,6 +33,7 @@ #include "../operator_common.h" #include "../elemwise_op_common.h" #include "np_broadcast_reduce_op.h" +#include "../../api/operator/op_utils.h" namespace mxnet { namespace op { @@ -67,13 +68,12 @@ struct NumpyPercentileParam : public dmlc::Parameter { .describe("inqut q is a scalar"); } void SetAttrDict(std::unordered_map* dict) { - std::ostringstream axis_s, interpolation_s, keepdims_s, q_scalar_s; + std::ostringstream axis_s, keepdims_s, q_scalar_s; axis_s << axis; - interpolation_s << interpolation; keepdims_s << keepdims; q_scalar_s << q_scalar; (*dict)["axis"] = axis_s.str(); - (*dict)["interpolation"] = interpolation_s.str(); + (*dict)["interpolation"] = MXNetPercentileType2String(interpolation); (*dict)["keepdims"] = keepdims_s.str(); (*dict)["q_scalar"] = q_scalar_s.str(); }