From 39bf8c231957cc2fe2173adfae61d688fc8e027c Mon Sep 17 00:00:00 2001 From: Ke Han Date: Mon, 23 Mar 2020 13:44:57 +0800 Subject: [PATCH] * vstack, row_stack FFI --- benchmark/python/ffi/benchmark_ffi.py | 3 +++ python/mxnet/ndarray/numpy/_op.py | 4 ++-- src/api/operator/numpy/np_matrix_op.cc | 20 ++++++++++++++++++++ src/api/operator/numpy/np_ordering_op.cc | 4 ++-- src/operator/numpy/np_matrix_op-inl.h | 5 +++++ src/operator/tensor/ordering_op-inl.h | 3 ++- 6 files changed, 34 insertions(+), 5 deletions(-) diff --git a/benchmark/python/ffi/benchmark_ffi.py b/benchmark/python/ffi/benchmark_ffi.py index 96d8e1d6658f..dbfea8e7783b 100644 --- a/benchmark/python/ffi/benchmark_ffi.py +++ b/benchmark/python/ffi/benchmark_ffi.py @@ -60,8 +60,11 @@ def prepare_workloads(): OpArgMngr.add_workload("add", pool['2x2'], pool['2x2']) OpArgMngr.add_workload("linalg.svd", pool['3x3']) OpArgMngr.add_workload("split", pool['3x3'], (0, 1, 2), axis=1) + OpArgMngr.add_workload("vstack", (pool['3x3'], pool['3x3'], pool['3x3'])) OpArgMngr.add_workload("argmax", pool['3x2'], axis=-1) OpArgMngr.add_workload("argmin", pool['3x2'], axis=-1) + OpArgMngr.add_workload("argsort", pool['3x2'], axis=-1) + OpArgMngr.add_workload("sort", pool['3x2'], axis=-1) OpArgMngr.add_workload("indices", dimensions=(1, 2, 3)) OpArgMngr.add_workload("subtract", pool['2x2'], pool['2x2']) OpArgMngr.add_workload("multiply", pool['2x2'], pool['2x2']) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index ee352fffbf27..a312b0db0055 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -4218,7 +4218,7 @@ def get_list(arrays): return [arr for arr in arrays] arrays = get_list(arrays) - return _npi.vstack(*arrays) + return _api_internal.vstack(*arrays) @set_module('mxnet.ndarray.numpy') @@ -4263,7 +4263,7 @@ def get_list(arrays): return [arr for arr in arrays] arrays = get_list(arrays) - return _npi.vstack(*arrays) + return _api_internal.vstack(*arrays) @set_module('mxnet.ndarray.numpy') diff --git a/src/api/operator/numpy/np_matrix_op.cc b/src/api/operator/numpy/np_matrix_op.cc index ae8421ac4010..754cee8ddc0b 100644 --- a/src/api/operator/numpy/np_matrix_op.cc +++ b/src/api/operator/numpy/np_matrix_op.cc @@ -191,4 +191,24 @@ MXNET_REGISTER_API("_npi.diag_indices_from") *ret = ndoutputs[0]; }); +MXNET_REGISTER_API("_npi.vstack") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_vstack"); + nnvm::NodeAttrs attrs; + op::NumpyVstackParam param; + param.num_args = args.size(); + + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_outputs = 0; + std::vector inputs; + for (int i = 0; i < param.num_args; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + auto ndoutputs = Invoke(op, &attrs, param.num_args, &inputs[0], &num_outputs, nullptr); + *ret = ndoutputs[0]; +}); + } // namespace mxnet diff --git a/src/api/operator/numpy/np_ordering_op.cc b/src/api/operator/numpy/np_ordering_op.cc index bf9e8033448a..ec0db28b4f9a 100644 --- a/src/api/operator/numpy/np_ordering_op.cc +++ b/src/api/operator/numpy/np_ordering_op.cc @@ -66,11 +66,11 @@ MXNET_REGISTER_API("_npi.argsort") } else { param.axis = args[1].operator int(); } - param.is_ascend = args[2].operator bool(); + param.is_ascend = true; if (args[3].type_code() == kNull) { param.dtype = mshadow::kFloat32; } else { - param.dtype = String2MXNetTypeWithBool(args[1].operator std::string()); + param.dtype = String2MXNetTypeWithBool(args[3].operator std::string()); } attrs.parsed = std::move(param); diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index 2e48596cee9c..b6dfe6e8fca9 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -55,6 +55,11 @@ struct NumpyVstackParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(num_args).set_lower_bound(1) .describe("Number of inputs to be vstacked."); } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream num_args_s; + num_args_s << num_args; + (*dict)["num_args"] = num_args_s.str(); + } }; struct NumpyColumnStackParam : public dmlc::Parameter { diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index f7226ac07d9f..a7351a82bed0 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -36,6 +36,7 @@ #include "../elemwise_op_common.h" #include "./sort_op.h" #include "./indexing_op.h" +#include "../../api/operator/op_utils.h" namespace mshadow { template @@ -145,7 +146,7 @@ struct ArgSortParam : public dmlc::Parameter { dtype_s << dtype; (*dict)["axis"] = axis_s.str(); (*dict)["is_ascend_s"] = is_ascend_s.str(); - (*dict)["dtype"] = String2MXNetTypeWithBool(dtype); + (*dict)["dtype"] = MXNetTypeWithBool2String(dtype); } };