Commit
* vstack, row_stack FFI
hanke580 committed Mar 23, 2020
1 parent 5245d9e commit 9ecba80
Showing 5 changed files with 30 additions and 3 deletions.
1 change: 1 addition & 0 deletions benchmark/python/ffi/benchmark_ffi.py
@@ -59,6 +59,7 @@ def prepare_workloads():
OpArgMngr.add_workload("add", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("linalg.svd", pool['3x3'])
OpArgMngr.add_workload("split", pool['3x3'], (0, 1, 2), axis=1)
OpArgMngr.add_workload("vstack", (pool['3x3'], pool['3x3'], pool['3x3']))
OpArgMngr.add_workload("argmax", pool['3x2'], axis=-1)
OpArgMngr.add_workload("argmin", pool['3x2'], axis=-1)
OpArgMngr.add_workload("indices", dimensions=(1, 2, 3))
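
The added workload exercises the new vstack FFI on three preallocated 3x3 arrays, so the benchmark measures per-call dispatch overhead rather than compute. A rough standalone sketch of what that workload does, assuming an MXNet build with the numpy interface; the timing loop and array names below are illustrative, not helpers from benchmark_ffi.py:

import timeit
import mxnet as mx

mx.npx.set_np()                        # enable NumPy semantics (may be optional for pure mx.np calls)
pool_3x3 = mx.np.ones((3, 3))
args = (pool_3x3, pool_3x3, pool_3x3)

mx.np.vstack(args)                     # warm-up so lazy initialization stays out of the timing
mx.nd.waitall()

t = timeit.timeit(lambda: mx.np.vstack(args), number=1000)
print("average call time: %.2f us" % (t / 1000 * 1e6))
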
4 changes: 2 additions & 2 deletions python/mxnet/ndarray/numpy/_op.py
@@ -4188,7 +4188,7 @@ def get_list(arrays):
         return [arr for arr in arrays]

     arrays = get_list(arrays)
-    return _npi.vstack(*arrays)
+    return _api_internal.vstack(*arrays)


 @set_module('mxnet.ndarray.numpy')
@@ -4233,7 +4233,7 @@ def get_list(arrays):
         return [arr for arr in arrays]

     arrays = get_list(arrays)
-    return _npi.vstack(*arrays)
+    return _api_internal.vstack(*arrays)


 @set_module('mxnet.ndarray.numpy')
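
Both touched wrappers share this pattern: np.vstack and (per the commit title) np.row_stack collect their inputs with get_list and now forward them through _api_internal instead of the legacy _npi module. A short usage sketch, assuming both functions are exposed on the mxnet.numpy front end; the array values are arbitrary examples:

import mxnet as mx

a = mx.np.arange(6).reshape(2, 3)
b = mx.np.zeros((1, 3))

v = mx.np.vstack((a, b))
r = mx.np.row_stack((a, b))     # same stacking along axis 0 as vstack
assert v.shape == r.shape == (3, 3)
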
20 changes: 20 additions & 0 deletions src/api/operator/numpy/np_matrix_op.cc
@@ -142,4 +142,24 @@ MXNET_REGISTER_API("_npi.rot90")
   *ret = ndoutputs[0];
 });

+MXNET_REGISTER_API("_npi.vstack")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  const nnvm::Op* op = Op::Get("_npi_vstack");
+  nnvm::NodeAttrs attrs;
+  op::NumpyVstackParam param;
+  param.num_args = args.size();
+
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict<op::NumpyVstackParam>(&attrs);
+  int num_outputs = 0;
+  std::vector<NDArray*> inputs;
+  for (int i = 0; i < param.num_args; ++i) {
+    inputs.push_back(args[i].operator mxnet::NDArray*());
+  }
+  auto ndoutputs = Invoke(op, &attrs, param.num_args, &inputs[0], &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
+
 } // namespace mxnet
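
The registration packs every positional argument into one NDArray input and records the count in param.num_args, which is why the Python wrapper unpacks its sequence with *arrays. A small illustration of that calling convention, assuming the standard mxnet.numpy front end (values are arbitrary):

import mxnet as mx

arrays = [mx.np.ones((2, 4)) for _ in range(3)]
stacked = mx.np.vstack(arrays)      # wrapper calls _api_internal.vstack(*arrays), so num_args == 3
assert stacked.shape == (6, 4)
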
5 changes: 5 additions & 0 deletions src/operator/numpy/np_matrix_op-inl.h
@@ -55,6 +55,11 @@ struct NumpyVstackParam : public dmlc::Parameter<NumpyVstackParam> {
     DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
     .describe("Number of inputs to be vstacked.");
   }
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream num_args_s;
+    num_args_s << num_args;
+    (*dict)["num_args"] = num_args_s.str();
+  }
 };

 struct NumpyColumnStackParam : public dmlc::Parameter<NumpyColumnStackParam> {
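
SetAttrDict serializes the parameter into the string-keyed attribute dictionary used on the FFI path; for this struct that is just the single num_args field. A minimal Python analogue of what the new method produces (the function name is illustrative):

def vstack_attr_dict(num_args):
    # NumpyVstackParam::SetAttrDict writes the one field in its string form
    return {"num_args": str(num_args)}

assert vstack_attr_dict(3) == {"num_args": "3"}
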
3 changes: 2 additions & 1 deletion src/operator/tensor/ordering_op-inl.h
@@ -36,6 +36,7 @@
#include "../elemwise_op_common.h"
#include "./sort_op.h"
#include "./indexing_op.h"
#include "../../api/operator/op_utils.h"

namespace mshadow {
template<typename xpu, int src_dim, typename DType, int dst_dim>
@@ -145,7 +146,7 @@ struct ArgSortParam : public dmlc::Parameter<ArgSortParam> {
     dtype_s << dtype;
     (*dict)["axis"] = axis_s.str();
(*dict)["is_ascend_s"] = is_ascend_s.str();
(*dict)["dtype"] = String2MXNetTypeWithBool(dtype);
(*dict)["dtype"] = MXNetTypeWithBool2String(dtype);
}
};

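
The one-line fix swaps the direction of the conversion: the attribute dict stores strings, so the integer dtype flag has to be turned into its name (MXNetTypeWithBool2String) rather than parsed from one. A hedged Python analogue, with a flag table assumed from mshadow's usual ordering rather than taken from this diff:

# illustrative mapping only; verify against mshadow's actual type flags
_TYPE_FLAG_TO_NAME = {
    0: "float32", 1: "float64", 2: "float16",
    3: "uint8", 4: "int32", 5: "int8", 6: "int64", 7: "bool",
}

def type_with_bool_to_string(type_flag):
    return _TYPE_FLAG_TO_NAME[type_flag]

assert type_with_bool_to_string(4) == "int32"
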
