Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
new ffi
Browse files Browse the repository at this point in the history
  • Loading branch information
Tommliu committed Mar 5, 2020
1 parent 12f54f4 commit 161127d
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 14 deletions.
11 changes: 7 additions & 4 deletions src/api/operator/numpy/np_bincount_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* \file np_bincount_op.cc
* \brief Implementation of the API of functions in src/operator/numpy/np_bincount_op.cc
*/
#include <mxnet/api_registry.h>
#include "../utils.h"
#include "../../../operator/numpy/np_bincount_op-inl.h"

Expand All @@ -38,17 +39,19 @@ MXNET_REGISTER_API("_npi.bincount")
param.minlength = args[2].operator int64_t();
param.has_weights = false;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
attrs.parsed = param; // remove std::move() ci error: trivially-copyable type
attrs.parsed = param;
attrs.op = op;
auto ndoutputs = Invoke<op::NumpyBincountParam>(op, &attrs, 1, inputs, &num_outputs, nullptr);
SetAttrDict<op::NumpyBincountParam>(&attrs);
auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
} else {
param.minlength = args[2].operator int64_t();
param.has_weights = true;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()};
attrs.parsed = param; // remove std::move() ci error: trivially-copyable type
attrs.parsed = param;
attrs.op = op;
auto ndoutputs = Invoke<op::NumpyBincountParam>(op, &attrs, 2, inputs, &num_outputs, nullptr);
SetAttrDict<op::NumpyBincountParam>(&attrs);
auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
});
Expand Down
21 changes: 15 additions & 6 deletions src/api/operator/numpy/np_broadcast_reduce_op_boolean.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* \file np_broadcast_reduce_op_boolean.cc
* \brief Implementation of the API of functions in src/operator/numpy/np_broadcast_reduce_op_boolean.cc
*/
#include <mxnet/api_registry.h>
#include "../utils.h"
#include "../../../operator/numpy/np_broadcast_reduce_op.h"

Expand All @@ -47,9 +48,13 @@ MXNET_REGISTER_API("_npi.all")
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
attrs.parsed = std::move(param);
attrs.op = op;
auto ndoutputs = Invoke<op::NumpyReduceAxesBoolParam>(
op, &attrs, 1, inputs, &num_outputs, outputs);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
SetAttrDict<op::NumpyReduceAxesBoolParam>(&attrs);
auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, outputs);
if (out) {
*ret = PythonArg(3);
} else {
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
});

MXNET_REGISTER_API("_npi.any")
Expand All @@ -73,9 +78,13 @@ MXNET_REGISTER_API("_npi.any")
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
attrs.parsed = std::move(param);
attrs.op = op;
auto ndoutputs = Invoke<op::NumpyReduceAxesBoolParam>(
op, &attrs, 1, inputs, &num_outputs, outputs);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
SetAttrDict<op::NumpyReduceAxesBoolParam>(&attrs);
auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, outputs);
if (out) {
*ret = PythonArg(3);
} else {
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
});

} // namespace mxnet
20 changes: 16 additions & 4 deletions src/api/operator/numpy/np_percentile_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
* \file np_percentile_op.cc
* \brief Implementation of the API of functions in src/operator/numpy/np_percentile_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../../../operator/numpy/np_percentile_op-inl.h"

Expand Down Expand Up @@ -69,15 +71,25 @@ MXNET_REGISTER_API("_npi.percentile")
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
attrs.parsed = std::move(param);
attrs.op = op;
auto ndoutputs = Invoke<op::NumpyPercentileParam>(op, &attrs, 1, inputs, &num_outputs, outputs);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
SetAttrDict<op::NumpyPercentileParam>(&attrs);
auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, outputs);
if (out) {
*ret = PythonArg(5);
} else {
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
} else {
param.q_scalar = dmlc::nullopt;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()};
attrs.parsed = std::move(param);
attrs.op = op;
auto ndoutputs = Invoke<op::NumpyPercentileParam>(op, &attrs, 2, inputs, &num_outputs, outputs);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
SetAttrDict<op::NumpyPercentileParam>(&attrs);
auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, outputs);
if (out) {
*ret = PythonArg(5);
} else {
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
}
});

Expand Down

0 comments on commit 161127d

Please sign in to comment.