Commit d0e17e5

hanke580 and szha authored
[Numpy] FFI: sort, argsort, vstack etc (apache#17857)
* sort FFI
* argsort FFI
* vstack, row_stack FFI
* greater FFI
* inner FFI
* multinomial FFI
* rand FFI
* randn FFI
* Fix input out of index and rscalar of greater
* Fix ndarray situation
* Fix sanity
* Fix lint
* Fix bugs
* Remove duplicate operator (greater)
* Fix Tuple downcast error (Integer only)
* Fix segmentation fault (pointer)

Co-authored-by: Sheng Zha <[email protected]>
1 parent 5c50475 commit d0e17e5
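The user-facing mxnet.numpy API is unchanged by this commit; only the dispatch path moves from the legacy _npi front end onto the new FFI (_api_internal). A quick smoke test of the migrated ops (a sketch; shapes follow the benchmark workloads below):

    import mxnet.numpy as np

    a = np.ones((3, 2))
    np.sort(a, axis=-1)                    # now routed through _api_internal.sort
    np.argsort(a, axis=-1)                 # _api_internal.argsort
    np.vstack((a, a))                      # _api_internal.vstack
    np.inner(a, a)                         # inner FFI
    np.random.multinomial(2, [1/6.] * 6)   # multinomial FFI
    np.random.rand(3, 2)                   # rand FFI
    np.random.randn(2, 2)                  # randn FFI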

9 files changed: +232 -17

benchmark/python/ffi/benchmark_ffi.py (+7, -4)

@@ -62,7 +62,6 @@ def prepare_workloads():
     OpArgMngr.add_workload("nan_to_num", pool['2x2'])
     OpArgMngr.add_workload("tri", 2, 3, 4)
     OpArgMngr.add_workload("tensordot", pool['2x2'], pool['2x2'], ((1, 0), (0, 1)))
-    OpArgMngr.add_workload("kron", pool['2x2'], pool['2x2'])
     OpArgMngr.add_workload("cumsum", pool['3x2'], axis=0, out=pool['3x2'])
     OpArgMngr.add_workload("random.shuffle", pool['3'])
     OpArgMngr.add_workload("equal", pool['2x2'], pool['2x2'])
@@ -100,11 +99,14 @@ def prepare_workloads():
     OpArgMngr.add_workload("trace", pool['2x2'])
     OpArgMngr.add_workload("transpose", pool['2x2'])
     OpArgMngr.add_workload("split", pool['3x3'], (0, 1, 2), axis=1)
+    OpArgMngr.add_workload("vstack", (pool['3x3'], pool['3x3'], pool['3x3']))
     OpArgMngr.add_workload("argmax", pool['3x2'], axis=-1)
     OpArgMngr.add_workload("argmin", pool['3x2'], axis=-1)
     OpArgMngr.add_workload("atleast_1d", pool['2'], pool['2x2'])
     OpArgMngr.add_workload("atleast_2d", pool['2'], pool['2x2'])
     OpArgMngr.add_workload("atleast_3d", pool['2'], pool['2x2'])
+    OpArgMngr.add_workload("argsort", pool['3x2'], axis=-1)
+    OpArgMngr.add_workload("sort", pool['3x2'], axis=-1)
     OpArgMngr.add_workload("indices", dimensions=(1, 2, 3))
     OpArgMngr.add_workload("subtract", pool['2x2'], pool['2x2'])
     OpArgMngr.add_workload("multiply", pool['2x2'], pool['2x2'])
@@ -115,6 +117,10 @@ def prepare_workloads():
     OpArgMngr.add_workload("power", pool['2x2'], pool['2x2'])
     OpArgMngr.add_workload("lcm", pool['2x2'].astype('int32'), pool['2x2'].astype('int32'))
     OpArgMngr.add_workload("diff", pool['2x2'], n=1, axis=-1)
+    OpArgMngr.add_workload("inner", pool['2x2'], pool['2x2'])
+    OpArgMngr.add_workload("random.multinomial", n=2, pvals=[1/6.]*6, size=(2,2))
+    OpArgMngr.add_workload("random.rand", 3, 2)
+    OpArgMngr.add_workload("random.randn", 2, 2)
     OpArgMngr.add_workload("nonzero", pool['2x2'])
     OpArgMngr.add_workload("tril", pool['2x2'], k=0)
     OpArgMngr.add_workload("random.choice", pool['2'], size=(2, 2))
@@ -144,9 +150,6 @@ def prepare_workloads():
     OpArgMngr.add_workload("random.logistic", loc=2, scale=2, size=(2,2))
     OpArgMngr.add_workload("random.gumbel", loc=2, scale=2, size=(2,2))
     OpArgMngr.add_workload("where", pool['2x3'], pool['2x3'], pool['2x1'])
-    OpArgMngr.add_workload("fmax", pool['2x2'], pool['2x2'])
-    OpArgMngr.add_workload("fmin", pool['2x2'], pool['2x2'])
-    OpArgMngr.add_workload("fmod", pool['2x2'], pool['2x2'])
     OpArgMngr.add_workload("may_share_memory", pool['2x3'][:0], pool['2x3'][:1])
     OpArgMngr.add_workload('squeeze', pool['2x2'], axis=None)
     OpArgMngr.add_workload("pad", pool['2x2'], pad_width=((1,2),(1,2)), mode="constant")

python/mxnet/ndarray/numpy/_op.py (+4, -4)

@@ -1620,7 +1620,7 @@ def argsort(a, axis=-1, kind=None, order=None):
     if order is not None:
         raise NotImplementedError("order not supported here")

-    return _npi.argsort(data=a, axis=axis, is_ascend=True, dtype='int64')
+    return _api_internal.argsort(a, axis, True, 'int64')


 @set_module('mxnet.ndarray.numpy')
@@ -1664,7 +1664,7 @@ def sort(a, axis=-1, kind=None, order=None):
     """
     if order is not None:
         raise NotImplementedError("order not supported here")
-    return _npi.sort(data=a, axis=axis, is_ascend=True)
+    return _api_internal.sort(a, axis, True)


 @set_module('mxnet.ndarray.numpy')
@@ -4581,7 +4581,7 @@ def get_list(arrays):
         return [arr for arr in arrays]

     arrays = get_list(arrays)
-    return _npi.vstack(*arrays)
+    return _api_internal.vstack(*arrays)


 @set_module('mxnet.ndarray.numpy')
@@ -4626,7 +4626,7 @@ def get_list(arrays):
         return [arr for arr in arrays]

     arrays = get_list(arrays)
-    return _npi.vstack(*arrays)
+    return _api_internal.vstack(*arrays)


 @set_module('mxnet.ndarray.numpy')
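All four hunks make the same swap: the _npi call passes keyword arguments through the legacy imperative front end, while _api_internal hands positional values straight to the packed function registered in C++ (see np_ordering_op.cc and np_matrix_op.cc below), so the argument order here must line up with the args[i] indices read on the C++ side. Results are meant to be unchanged; a quick sketch:

    import mxnet.numpy as np

    a = np.array([3., 1., 2.])
    print(np.sort(a))     # [1. 2. 3.]
    print(np.argsort(a))  # indices [1, 2, 0], int64 per the dtype passed above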

python/mxnet/ndarray/numpy/random.py (+5, -9)

@@ -21,7 +21,6 @@
 from ...context import current_context
 from . import _internal as _npi
 from . import _api_internal
-from ..ndarray import NDArray


 __all__ = ['randint', 'uniform', 'normal', "choice", "rand", "multinomial", "multivariate_normal",
@@ -331,14 +330,11 @@ def multinomial(n, pvals, size=None):
     >>> np.random.multinomial(100, [1.0 / 3, 2.0 / 3])
     array([32, 68])
     """
-    if isinstance(pvals, NDArray):
-        return _npi.multinomial(pvals, pvals=None, n=n, size=size)
-    else:
-        if isinstance(pvals, np.ndarray):
-            raise ValueError('numpy ndarray is not supported!')
-        if any(isinstance(i, list) for i in pvals):
-            raise ValueError('object too deep for desired array')
-        return _npi.multinomial(n=n, pvals=pvals, size=size)
+    if isinstance(pvals, np.ndarray):
+        raise ValueError('numpy ndarray is not supported!')
+    if any(isinstance(i, list) for i in pvals):
+        raise ValueError('object too deep for desired array')
+    return _api_internal.multinomial(n, pvals, size)


 def rayleigh(scale=1.0, size=None, ctx=None, out=None):
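With the NDArray branch gone, the wrapper only validates pvals before deferring to the FFI (NDArray-typed pvals now flow through to the C++ registration shown below). A sketch of what passes and what the two remaining checks reject; onp here is plain NumPy:

    import numpy as onp
    import mxnet.numpy as np

    np.random.multinomial(2, [1/6.] * 6, size=(2, 2))  # list/tuple pvals: OK
    # np.random.multinomial(2, onp.array([.5, .5]))    # ValueError: numpy ndarray is not supported!
    # np.random.multinomial(2, [[.5], [.5]])           # ValueError: object too deep for desired array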

src/api/operator/numpy/np_matrix_op.cc (+20, -0)

@@ -615,4 +615,24 @@ MXNET_REGISTER_API("_npi.tril_indices")
   *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end());
 });

+MXNET_REGISTER_API("_npi.vstack")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  const nnvm::Op* op = Op::Get("_npi_vstack");
+  nnvm::NodeAttrs attrs;
+  op::NumpyVstackParam param;
+  param.num_args = args.size();
+
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict<op::NumpyVstackParam>(&attrs);
+  int num_outputs = 0;
+  std::vector<NDArray*> inputs_vec(args.size(), nullptr);
+  for (int i = 0; i < args.size(); ++i) {
+    inputs_vec[i] = args[i].operator mxnet::NDArray*();
+  }
+  NDArray** inputs = inputs_vec.data();
+  auto ndoutputs = Invoke(op, &attrs, param.num_args, inputs, &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
 }  // namespace mxnet
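Unlike the fixed-arity registrations below, vstack is variadic: num_args comes from args.size() and every argument is forwarded as an input NDArray, which is why the Python wrapper splats the list (*arrays). A usage sketch matching the benchmark workload:

    import mxnet.numpy as np

    a = np.ones((3, 3))
    print(np.vstack((a, a, a)).shape)  # (9, 3)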
src/api/operator/numpy/np_ordering_op.cc (new file, +88)
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_ordering_op.cc
+ * \brief Implementation of the API of functions in src/operator/tensor/ordering_op.cc
+ */
+#include <mxnet/api_registry.h>
+#include <mxnet/runtime/packed_func.h>
+#include "../utils.h"
+#include "../../../operator/tensor/ordering_op-inl.h"
+
+namespace mxnet {
+
+MXNET_REGISTER_API("_npi.sort")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  const nnvm::Op* op = Op::Get("_npi_sort");
+  nnvm::NodeAttrs attrs;
+  op::SortParam param;
+
+  if (args[1].type_code() == kNull) {
+    param.axis = dmlc::nullopt;
+  } else {
+    param.axis = args[1].operator int();
+  }
+  param.is_ascend = true;
+
+  attrs.parsed = std::move(param);
+  attrs.op = op;
+
+  int num_inputs = 1;
+  NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
+
+  int num_outputs = 0;
+  SetAttrDict<op::SortParam>(&attrs);
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
+  *ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
+});
+
+MXNET_REGISTER_API("_npi.argsort")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  const nnvm::Op* op = Op::Get("_npi_argsort");
+  nnvm::NodeAttrs attrs;
+  op::ArgSortParam param;
+
+  if (args[1].type_code() == kNull) {
+    param.axis = dmlc::nullopt;
+  } else {
+    param.axis = args[1].operator int();
+  }
+  param.is_ascend = true;
+  if (args[3].type_code() == kNull) {
+    param.dtype = mshadow::kFloat32;
+  } else {
+    param.dtype = String2MXNetTypeWithBool(args[3].operator std::string());
+  }
+
+  attrs.parsed = std::move(param);
+  attrs.op = op;
+
+  int num_inputs = 1;
+  NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
+
+  int num_outputs = 0;
+  SetAttrDict<op::ArgSortParam>(&attrs);
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
+  *ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
+});
+
+}  // namespace mxnet
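In both registrations a null second argument (axis=None in Python) becomes dmlc::nullopt, which for these operators means working on the flattened array, and argsort's dtype falls back to float32 when unspecified. A sketch, assuming NumPy-compatible axis semantics:

    import mxnet.numpy as np

    a = np.array([[3., 1.], [2., 4.]])
    print(np.sort(a, axis=None).shape)  # (4,) -- flattened when axis is None
    print(np.argsort(a, axis=-1))       # per-row indices: [[1, 0], [0, 1]]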
src/api/operator/numpy/random/np_multinomial_op.cc (new file, +75)
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_multinomial_op.cc
+ * \brief Implementation of the API of functions in src/operator/numpy/random/np_multinomial_op.cc
+ */
+#include <mxnet/api_registry.h>
+#include <mxnet/runtime/packed_func.h>
+#include <vector>
+#include "../../utils.h"
+#include "../../../../operator/numpy/random/np_multinomial_op.h"
+
+namespace mxnet {
+
+MXNET_REGISTER_API("_npi.multinomial")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  const nnvm::Op* op = Op::Get("_npi_multinomial");
+  nnvm::NodeAttrs attrs;
+  op::NumpyMultinomialParam param;
+  NDArray** inputs = new NDArray*[1]();
+  int num_inputs = 0;
+
+  // parse int
+  param.n = args[0].operator int();
+
+  // parse pvals
+  if (args[1].type_code() == kNull) {
+    param.pvals = dmlc::nullopt;
+  } else if (args[1].type_code() == kNDArrayHandle) {
+    param.pvals = dmlc::nullopt;
+    inputs[0] = args[1].operator mxnet::NDArray*();
+    num_inputs = 1;
+  } else {
+    param.pvals = Obj2Tuple<double, Float>(args[1].operator ObjectRef());
+  }
+
+  // parse size
+  if (args[2].type_code() == kNull) {
+    param.size = dmlc::nullopt;
+  } else {
+    if (args[2].type_code() == kDLInt) {
+      param.size = mxnet::Tuple<int>(1, args[2].operator int64_t());
+    } else {
+      param.size = mxnet::Tuple<int>(args[2].operator ObjectRef());
+    }
+  }
+
+  attrs.parsed = std::move(param);
+  attrs.op = op;
+  SetAttrDict<op::NumpyMultinomialParam>(&attrs);
+  inputs = num_inputs == 0 ? nullptr : inputs;
+  int num_outputs = 0;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
+
+}  // namespace mxnet
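pvals can arrive three ways: null, an NDArray handle (forwarded as the op's single input, with param.pvals left empty), or a Python sequence converted via Obj2Tuple; size likewise accepts either a single integer or a tuple. A sketch of the paths as seen from Python:

    import mxnet.numpy as np

    np.random.multinomial(2, [1/6.] * 6, size=(2, 2))  # sequence pvals -> Obj2Tuple branch
    np.random.multinomial(2, np.array([0.5, 0.5]))     # NDArray pvals -> kNDArrayHandle branch
    np.random.multinomial(2, [0.5, 0.5], size=4)       # scalar size -> kDLInt branch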

src/operator/numpy/np_matrix_op-inl.h (+5, -0)

@@ -61,6 +61,11 @@ struct NumpyVstackParam : public dmlc::Parameter<NumpyVstackParam> {
     DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
     .describe("Number of inputs to be vstacked.");
   }
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream num_args_s;
+    num_args_s << num_args;
+    (*dict)["num_args"] = num_args_s.str();
+  }
 };

 struct NumpyColumnStackParam : public dmlc::Parameter<NumpyColumnStackParam> {
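SetAttrDict is the flip side of the FFI bodies filling attrs.parsed directly: since no keyword strings cross the language boundary anymore, each param type back-fills the string attr dict itself, keyed by its DMLC field names. For the vstack registration above, the dict would hold just the stringified count (hypothetical value for three stacked arrays):

    # what NumpyVstackParam::SetAttrDict would produce, sketched as a dict
    vstack_attr_dict = {"num_args": "3"}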

src/operator/numpy/random/np_multinomial_op.h (+10, -0)

@@ -27,6 +27,7 @@

 #include <mxnet/operator_util.h>
 #include <vector>
+#include <string>
 #include "../../mshadow_op.h"
 #include "../../mxnet_op.h"
 #include "../../operator_common.h"
@@ -55,6 +56,15 @@ struct NumpyMultinomialParam : public dmlc::Parameter<NumpyMultinomialParam> {
              "e.g., (m, n, k), then m * n * k samples are drawn. "
              "Default is None, in which case a single value is returned.");
   }
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream n_s, pvals_s, size_s;
+    n_s << n;
+    pvals_s << pvals;
+    size_s << size;
+    (*dict)["n"] = n_s.str();
+    (*dict)["pvals"] = pvals_s.str();
+    (*dict)["size"] = size_s.str();
+  }
 };

 inline bool NumpyMultinomialOpShape(const nnvm::NodeAttrs& attrs,

src/operator/tensor/ordering_op-inl.h (+18, -0)

@@ -30,11 +30,13 @@
 #include <mshadow/tensor.h>
 #include <algorithm>
 #include <vector>
+#include <string>
 #include <type_traits>
 #include "../mshadow_op.h"
 #include "../elemwise_op_common.h"
 #include "./sort_op.h"
 #include "./indexing_op.h"
+#include "../../api/operator/op_utils.h"

 namespace mshadow {
 template<typename xpu, int src_dim, typename DType, int dst_dim>
@@ -105,6 +107,13 @@ struct SortParam : public dmlc::Parameter<SortParam> {
     DMLC_DECLARE_FIELD(is_ascend).set_default(true)
     .describe("Whether to sort in ascending or descending order.");
   }
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream axis_s, is_ascend_s;
+    axis_s << axis;
+    is_ascend_s << is_ascend;
+    (*dict)["axis"] = axis_s.str();
+    (*dict)["is_ascend"] = is_ascend_s.str();
+  }
 };

 struct ArgSortParam : public dmlc::Parameter<ArgSortParam> {
@@ -130,6 +139,15 @@ struct ArgSortParam : public dmlc::Parameter<ArgSortParam> {
              " \"both\". An error will be raised if the selected data type cannot precisely "
              "represent the indices.");
   }
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream axis_s, is_ascend_s, dtype_s;
+    axis_s << axis;
+    is_ascend_s << is_ascend;
+    dtype_s << dtype;
+    (*dict)["axis"] = axis_s.str();
+    (*dict)["is_ascend"] = is_ascend_s.str();
+    (*dict)["dtype"] = MXNetTypeWithBool2String(dtype);
+  }
 };

 inline void ParseTopKParam(const TShape& src_shape,
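The same convention applies here: the keys must match the DMLC_DECLARE_FIELD names so the strings can round-trip through the ordinary parameter parser, and dtype is rendered with MXNetTypeWithBool2String rather than the raw enum value. Expected contents for the defaults (hypothetical, with bools stringified as 0/1 by the ostreams):

    sort_attr_dict    = {"axis": "-1", "is_ascend": "1"}
    argsort_attr_dict = {"axis": "-1", "is_ascend": "1", "dtype": "float32"}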
