From 161127d90177cfb54d854045b01a294bf433f054 Mon Sep 17 00:00:00 2001 From: Minghao Liu Date: Thu, 5 Mar 2020 11:58:33 +0000 Subject: [PATCH] new ffi --- src/api/operator/numpy/np_bincount_op.cc | 11 ++++++---- .../numpy/np_broadcast_reduce_op_boolean.cc | 21 +++++++++++++------ src/api/operator/numpy/np_percentile_op.cc | 20 ++++++++++++++---- 3 files changed, 38 insertions(+), 14 deletions(-) diff --git a/src/api/operator/numpy/np_bincount_op.cc b/src/api/operator/numpy/np_bincount_op.cc index e98c1a4a750d..c51ecd34c335 100644 --- a/src/api/operator/numpy/np_bincount_op.cc +++ b/src/api/operator/numpy/np_bincount_op.cc @@ -21,6 +21,7 @@ * \file np_bincount_op.cc * \brief Implementation of the API of functions in src/operator/numpy/np_bincount_op.cc */ +#include #include "../utils.h" #include "../../../operator/numpy/np_bincount_op-inl.h" @@ -38,17 +39,19 @@ MXNET_REGISTER_API("_npi.bincount") param.minlength = args[2].operator int64_t(); param.has_weights = false; NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - attrs.parsed = param; // remove std::move() ci error: trivially-copyable type + attrs.parsed = param; attrs.op = op; - auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, nullptr); + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, nullptr); *ret = reinterpret_cast(ndoutputs[0]); } else { param.minlength = args[2].operator int64_t(); param.has_weights = true; NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; - attrs.parsed = param; // remove std::move() ci error: trivially-copyable type + attrs.parsed = param; attrs.op = op; - auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, nullptr); + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, nullptr); *ret = reinterpret_cast(ndoutputs[0]); } }); diff --git a/src/api/operator/numpy/np_broadcast_reduce_op_boolean.cc b/src/api/operator/numpy/np_broadcast_reduce_op_boolean.cc index 32f70e8e2fc2..cba69fa2b1c1 100644 --- a/src/api/operator/numpy/np_broadcast_reduce_op_boolean.cc +++ b/src/api/operator/numpy/np_broadcast_reduce_op_boolean.cc @@ -21,6 +21,7 @@ * \file np_broadcast_reduce_op_boolean.cc * \brief Implementation of the API of functions in src/operator/numpy/np_broadcast_reduce_op_boolean.cc */ +#include #include "../utils.h" #include "../../../operator/numpy/np_broadcast_reduce_op.h" @@ -47,9 +48,13 @@ MXNET_REGISTER_API("_npi.all") NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; attrs.parsed = std::move(param); attrs.op = op; - auto ndoutputs = Invoke( - op, &attrs, 1, inputs, &num_outputs, outputs); - *ret = reinterpret_cast(ndoutputs[0]); + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(3); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } }); MXNET_REGISTER_API("_npi.any") @@ -73,9 +78,13 @@ MXNET_REGISTER_API("_npi.any") NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; attrs.parsed = std::move(param); attrs.op = op; - auto ndoutputs = Invoke( - op, &attrs, 1, inputs, &num_outputs, outputs); - *ret = reinterpret_cast(ndoutputs[0]); + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(3); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_percentile_op.cc b/src/api/operator/numpy/np_percentile_op.cc index 038e2d39c3f4..f978816a9534 100644 --- a/src/api/operator/numpy/np_percentile_op.cc +++ b/src/api/operator/numpy/np_percentile_op.cc @@ -21,6 +21,8 @@ * \file np_percentile_op.cc * \brief Implementation of the API of functions in src/operator/numpy/np_percentile_op.cc */ +#include +#include #include "../utils.h" #include "../../../operator/numpy/np_percentile_op-inl.h" @@ -69,15 +71,25 @@ MXNET_REGISTER_API("_npi.percentile") NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; attrs.parsed = std::move(param); attrs.op = op; - auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, outputs); - *ret = reinterpret_cast(ndoutputs[0]); + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(5); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } } else { param.q_scalar = dmlc::nullopt; NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; attrs.parsed = std::move(param); attrs.op = op; - auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, outputs); - *ret = reinterpret_cast(ndoutputs[0]); + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(5); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } } });