Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
new ffi
Browse files Browse the repository at this point in the history
  • Loading branch information
Tommliu committed Mar 10, 2020
1 parent 12f54f4 commit 3d8b736
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 19 deletions.
13 changes: 9 additions & 4 deletions src/api/operator/numpy/np_bincount_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* \file np_bincount_op.cc
* \brief Implementation of the API of functions in src/operator/numpy/np_bincount_op.cc
*/
#include <mxnet/api_registry.h>
#include "../utils.h"
#include "../../../operator/numpy/np_bincount_op-inl.h"

Expand All @@ -38,17 +39,21 @@ MXNET_REGISTER_API("_npi.bincount")
param.minlength = args[2].operator int64_t();
param.has_weights = false;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
attrs.parsed = param; // remove std::move() ci error: trivially-copyable type
int num_inputs = 1;
attrs.parsed = param;
attrs.op = op;
auto ndoutputs = Invoke<op::NumpyBincountParam>(op, &attrs, 1, inputs, &num_outputs, nullptr);
SetAttrDict<op::NumpyBincountParam>(&attrs);
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
} else {
param.minlength = args[2].operator int64_t();
param.has_weights = true;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()};
attrs.parsed = param; // remove std::move() ci error: trivially-copyable type
int num_inputs = 2;
attrs.parsed = param;
attrs.op = op;
auto ndoutputs = Invoke<op::NumpyBincountParam>(op, &attrs, 2, inputs, &num_outputs, nullptr);
SetAttrDict<op::NumpyBincountParam>(&attrs);
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
});
Expand Down
24 changes: 18 additions & 6 deletions src/api/operator/numpy/np_broadcast_reduce_op_boolean.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
* \file np_broadcast_reduce_op_boolean.cc
* \brief Implementation of the API of functions in src/operator/numpy/np_broadcast_reduce_op_boolean.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../../../operator/numpy/np_broadcast_reduce_op.h"

Expand All @@ -45,11 +47,16 @@ MXNET_REGISTER_API("_npi.all")
}
param.keepdims = args[2].operator bool();
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
attrs.parsed = std::move(param);
attrs.op = op;
auto ndoutputs = Invoke<op::NumpyReduceAxesBoolParam>(
op, &attrs, 1, inputs, &num_outputs, outputs);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
SetAttrDict<op::NumpyReduceAxesBoolParam>(&attrs);
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
if (out) {
*ret = PythonArg(3);
} else {
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
});

MXNET_REGISTER_API("_npi.any")
Expand All @@ -71,11 +78,16 @@ MXNET_REGISTER_API("_npi.any")
}
param.keepdims = args[2].operator bool();
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
attrs.parsed = std::move(param);
attrs.op = op;
auto ndoutputs = Invoke<op::NumpyReduceAxesBoolParam>(
op, &attrs, 1, inputs, &num_outputs, outputs);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
SetAttrDict<op::NumpyReduceAxesBoolParam>(&attrs);
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
if (out) {
*ret = PythonArg(3);
} else {
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
});

} // namespace mxnet
26 changes: 20 additions & 6 deletions src/api/operator/numpy/np_percentile_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
* \file np_percentile_op.cc
* \brief Implementation of the API of functions in src/operator/numpy/np_percentile_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../../../operator/numpy/np_percentile_op-inl.h"

namespace mxnet {

inline int String2MXnetPercentileType(const std::string& s) {
inline int String2MXNetPercentileType(const std::string& s) {
using namespace op;
if (s == "linear") {
return percentile_enum::kLinear;
Expand Down Expand Up @@ -62,22 +64,34 @@ MXNET_REGISTER_API("_npi.percentile")
} else {
param.axis = Tuple<int>(args[2].operator ObjectRef());
}
param.interpolation = String2MXnetPercentileType(args[3].operator std::string());
param.interpolation = String2MXNetPercentileType(args[3].operator std::string());
param.keepdims = args[4].operator bool();
if (args[1].type_code() == kDLInt || args[1].type_code() == kDLFloat) {
param.q_scalar = args[1].operator double();
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
attrs.parsed = std::move(param);
attrs.op = op;
auto ndoutputs = Invoke<op::NumpyPercentileParam>(op, &attrs, 1, inputs, &num_outputs, outputs);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
SetAttrDict<op::NumpyPercentileParam>(&attrs);
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
if (out) {
*ret = PythonArg(5);
} else {
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
} else {
param.q_scalar = dmlc::nullopt;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()};
int num_inputs = 2;
attrs.parsed = std::move(param);
attrs.op = op;
auto ndoutputs = Invoke<op::NumpyPercentileParam>(op, &attrs, 2, inputs, &num_outputs, outputs);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
SetAttrDict<op::NumpyPercentileParam>(&attrs);
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
if (out) {
*ret = PythonArg(5);
} else {
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
}
});

Expand Down
21 changes: 21 additions & 0 deletions src/api/operator/op_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include "op_utils.h"
#include <mxnet/base.h>
#include "../../operator/numpy/np_percentile_op-inl.h"

namespace mxnet {

Expand Down Expand Up @@ -52,4 +53,24 @@ std::string String2MXNetTypeWithBool(int dtype) {
return "";
}

std::string MXNetPercentileType2String(int interpolation) {
using namespace op;
switch (interpolation) {
case percentile_enum::kLinear:
return "linear";
case percentile_enum::kLower:
return "lower";
case percentile_enum::kHigher:
return "higher";
case percentile_enum::kMidpoint:
return "midpoint";
case percentile_enum::kNearest:
return "nearest";
default:
LOG(FATAL) << "Unknown type enum " << interpolation;
}
LOG(FATAL) << "should not reach here ";
return "";
}

} // namespace mxnet
1 change: 1 addition & 0 deletions src/api/operator/op_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
namespace mxnet {

std::string String2MXNetTypeWithBool(int dtype);
std::string MXNetPercentileType2String(int interpolation);

} // namespace mxnet

Expand Down
6 changes: 3 additions & 3 deletions src/operator/numpy/np_percentile_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "../operator_common.h"
#include "../elemwise_op_common.h"
#include "np_broadcast_reduce_op.h"
#include "../../api/operator/op_utils.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -67,13 +68,12 @@ struct NumpyPercentileParam : public dmlc::Parameter<NumpyPercentileParam> {
.describe("inqut q is a scalar");
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream axis_s, interpolation_s, keepdims_s, q_scalar_s;
std::ostringstream axis_s, keepdims_s, q_scalar_s;
axis_s << axis;
interpolation_s << interpolation;
keepdims_s << keepdims;
q_scalar_s << q_scalar;
(*dict)["axis"] = axis_s.str();
(*dict)["interpolation"] = interpolation_s.str();
(*dict)["interpolation"] = MXNetPercentileType2String(interpolation);
(*dict)["keepdims"] = keepdims_s.str();
(*dict)["q_scalar"] = q_scalar_s.str();
}
Expand Down

0 comments on commit 3d8b736

Please sign in to comment.