From e2c924be8fde8c789597170ede66c56187ef3924 Mon Sep 17 00:00:00 2001 From: "Huang, Guangtai" Date: Mon, 28 Oct 2019 16:24:05 +0800 Subject: [PATCH] [NumPy][Operator] NumPy operator `may_share_memory` and `shares_memory` (#16533) * init * finish & fix bug of 'take' * fix bug * add dispatch --- python/mxnet/ndarray/numpy/_op.py | 76 +++++++++++++++++- python/mxnet/numpy/multiarray.py | 79 ++++++++++++++++++- python/mxnet/numpy_dispatch_protocol.py | 4 +- python/mxnet/symbol/numpy/_symbol.py | 41 +++++++++- src/operator/numpy/np_memory_op.cc | 62 +++++++++++++++ src/operator/numpy/np_memory_op.cu | 34 ++++++++ src/operator/numpy/np_memory_op.h | 75 ++++++++++++++++++ .../unittest/test_numpy_interoperability.py | 15 ++++ tests/python/unittest/test_numpy_op.py | 27 ++++++- 9 files changed, 406 insertions(+), 7 deletions(-) create mode 100644 src/operator/numpy/np_memory_op.cc create mode 100644 src/operator/numpy/np_memory_op.cu create mode 100644 src/operator/numpy/np_memory_op.h diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index fdb9694146b5..84aa4a1572d9 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -38,7 +38,7 @@ 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', - 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero'] + 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory'] @set_module('mxnet.ndarray.numpy') @@ -4909,3 +4909,77 @@ def nonzero(a): """ out = _npi.nonzero(a).transpose() return tuple([out[i] for i in range(len(out))]) + + +@set_module('mxnet.ndarray.numpy') +def shares_memory(a, b, max_work=None): + """ + Determine if two arrays share memory + + Parameters + ---------- + a, b : ndarray + Input arrays + + Returns + ------- + out : bool + + See Also + -------- + may_share_memory + + Examples + -------- + >>> np.may_share_memory(np.array([1,2]), np.array([5,8,9])) + False + + This function differs from the original `numpy.shares_memory + `_ in + the following way(s): + + - Does not support `max_work`, it is a dummy argument + - Actually it is same as `may_share_memory` in MXNet DeepNumPy + """ + return _npi.share_memory(a, b).item() + + +@set_module('mxnet.ndarray.numpy') +def may_share_memory(a, b, max_work=None): + """ + Determine if two arrays might share memory + + A return of True does not necessarily mean that the two arrays + share any element. It just means that they *might*. + + Only the memory bounds of a and b are checked by default. 
+ + Parameters + ---------- + a, b : ndarray + Input arrays + + Returns + ------- + out : bool + + See Also + -------- + shares_memory + + Examples + -------- + >>> np.may_share_memory(np.array([1,2]), np.array([5,8,9])) + False + >>> x = np.zeros([3, 4]) + >>> np.may_share_memory(x[:,0], x[:,1]) + True + + This function differs from the original `numpy.may_share_memory + `_ in + the following way(s): + + - Does not support `max_work`, it is a dummy argument + - Actually it is same as `shares_memory` in MXNet DeepNumPy + """ + return _npi.share_memory(a, b).item() diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 5c9de8194a74..ef88638c857e 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -55,7 +55,8 @@ 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', - 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero'] + 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', + 'may_share_memory'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -1330,7 +1331,7 @@ def take(self, indices, axis=None, mode='raise'): # pylint: disable=arguments-d The arguments are the same as for :py:func:`take`, with this array as data. """ - take(self, indices, axis, mode=mode) + return take(self, indices, axis, mode=mode) def one_hot(self, *args, **kwargs): """Convenience fluent method for :py:func:`one_hot`. @@ -6900,3 +6901,77 @@ def nonzero(a): (array([1, 1, 1, 2, 2, 2], dtype=int64), array([0, 1, 2, 0, 1, 2], dtype=int64)) """ return _mx_nd_np.nonzero(a) + + +@set_module('mxnet.numpy') +def shares_memory(a, b, max_work=None): + """ + Determine if two arrays share memory + + Parameters + ---------- + a, b : ndarray + Input arrays + + Returns + ------- + out : bool + + See Also + -------- + may_share_memory + + Examples + -------- + >>> np.may_share_memory(np.array([1,2]), np.array([5,8,9])) + False + + This function differs from the original `numpy.shares_memory + `_ in + the following way(s): + + - Does not support `max_work`, it is a dummy argument + - Actually it is same as `may_share_memory` in MXNet DeepNumPy + """ + return _mx_nd_np.shares_memory(a, b, max_work) + + +@set_module('mxnet.numpy') +def may_share_memory(a, b, max_work=None): + """ + Determine if two arrays might share memory + + A return of True does not necessarily mean that the two arrays + share any element. It just means that they *might*. + + Only the memory bounds of a and b are checked by default. 
+ + Parameters + ---------- + a, b : ndarray + Input arrays + + Returns + ------- + out : bool + + See Also + -------- + shares_memory + + Examples + -------- + >>> np.may_share_memory(np.array([1,2]), np.array([5,8,9])) + False + >>> x = np.zeros([3, 4]) + >>> np.may_share_memory(x[:,0], x[:,1]) + True + + This function differs from the original `numpy.may_share_memory + `_ in + the following way(s): + + - Does not support `max_work`, it is a dummy argument + - Actually it is same as `shares_memory` in MXNet DeepNumPy + """ + return _mx_nd_np.may_share_memory(a, b, max_work) diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index cec2f245a5e1..6a5f166a70eb 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -127,7 +127,9 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'tril', 'meshgrid', 'outer', - 'einsum' + 'einsum', + 'shares_memory', + 'may_share_memory', ] diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index ddf2feb30b18..2e6d41446930 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -40,7 +40,7 @@ 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', - 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide'] + 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory'] def _num_outputs(sym): @@ -4590,4 +4590,43 @@ def einsum(*operands, **kwargs): return _npi.einsum(*operands, subscripts=subscripts, out=out, optimize=int(optimize_arg)) +@set_module('mxnet.symbol.numpy') +def shares_memory(a, b, max_work=None): + """ + Determine if two arrays share memory + + Parameters + ---------- + a, b : _Symbol + Input arrays + + Returns + ------- + out : _Symbol + """ + return _npi.share_memory(a, b) + + +@set_module('mxnet.symbol.numpy') +def may_share_memory(a, b, max_work=None): + """ + Determine if two arrays might share memory + + A return of True does not necessarily mean that the two arrays + share any element. It just means that they *might*. + + Only the memory bounds of a and b are checked by default. + + Parameters + ---------- + a, b : _Symbol + Input arrays + + Returns + ------- + out : _Symbol + """ + return _npi.share_memory(a, b) + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/numpy/np_memory_op.cc b/src/operator/numpy/np_memory_op.cc new file mode 100644 index 000000000000..522998e9c45d --- /dev/null +++ b/src/operator/numpy/np_memory_op.cc @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_memory_op.cc + */ + +#include "./np_memory_op.h" + +namespace mxnet { +namespace op { + +inline bool NumpyShareMemoryType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool); + return out_attrs->at(0) != -1; +} + +inline bool NumpyShareMemoryShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(0, -1)); + return true; +} + +NNVM_REGISTER_OP(_npi_share_memory) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a", "b"}; + }) +.set_attr("FInferShape", NumpyShareMemoryShape) +.set_attr("FInferType", NumpyShareMemoryType) +.set_attr("FCompute", NumpyShareMemoryCompute) +.add_argument("a", "NDArray-or-Symbol", "First input") +.add_argument("b", "NDArray-or-Symbol", "Second input"); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_memory_op.cu b/src/operator/numpy/np_memory_op.cu new file mode 100644 index 000000000000..61bf70d26b8e --- /dev/null +++ b/src/operator/numpy/np_memory_op.cu @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_memory_op.cu + */ + +#include "./np_memory_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_share_memory) +.set_attr("FCompute", NumpyShareMemoryCompute); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_memory_op.h b/src/operator/numpy/np_memory_op.h new file mode 100644 index 000000000000..2c0f3f063f65 --- /dev/null +++ b/src/operator/numpy/np_memory_op.h @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_memory_op.h + * \brief Function definition of numpy memory op + */ + +#ifndef MXNET_OPERATOR_NUMPY_NP_MEMORY_OP_H_ +#define MXNET_OPERATOR_NUMPY_NP_MEMORY_OP_H_ + +#include +#include +#include +#include "../operator_common.h" + +namespace mxnet { +namespace op { + +template +void NumpyShareMemoryCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet_op; + using namespace mshadow; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + Stream *s = ctx.get_stream(); + const TBlob& a = inputs[0]; + const TBlob& b = inputs[1]; + Tensor outdata = outputs[0].FlatTo1D(s); + + if (a.Size() == 0 || b.Size() == 0) { + ASSIGN_DISPATCH(outdata, OpReqType::kWriteTo, false); + return; + } + MSHADOW_TYPE_SWITCH_WITH_BOOL(a.type_flag_, AType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(b.type_flag_, BType, { + uint64_t start1 = reinterpret_cast(a.dptr_); + uint64_t end1 = start1 + a.Size() * sizeof(AType); + uint64_t start2 = reinterpret_cast(b.dptr_); + uint64_t end2 = start2 + b.Size() * sizeof(BType); + if (!(start1 < end2 && start2 < end1 && start1 < end1 && start2 < end2)) { + ASSIGN_DISPATCH(outdata, OpReqType::kWriteTo, false); + } else { + ASSIGN_DISPATCH(outdata, OpReqType::kWriteTo, true); + } + }); + }); + return; +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_MEMORY_OP_H_ diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 860fecc5cda0..624fc0a107b0 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1234,6 +1234,8 @@ def check_interoperability(op_list): for name in op_list: if name in _TVM_OPS and not is_op_runnable(): continue + if name in ['shares_memory', 'may_share_memory']: # skip list + continue print('Dispatch test:', name) workloads = OpArgMngr.get_workloads(name) assert workloads is not None, 'Workloads for operator `{}` has not been ' \ @@ -1243,6 +1245,19 @@ def check_interoperability(op_list): _check_interoperability_helper(name, *workload['args'], **workload['kwargs']) +@with_seed() +@use_np +@with_array_function_protocol +def test_np_memory_array_function(): + ops = [_np.shares_memory, _np.may_share_memory] + for op in ops: + data_mx = np.zeros([13, 21, 23, 22], dtype=np.float32) + data_np = _np.zeros([13, 21, 23, 22], dtype=np.float32) + assert op(data_mx[0,:,:,:], data_mx[1,:,:,:]) == op(data_np[0,:,:,:], data_np[1,:,:,:]) + assert op(data_mx[0,0,0,2:5], data_mx[0,0,0,4:7]) == op(data_np[0,0,0,2:5], data_np[0,0,0,4:7]) + assert op(data_mx, np.ones((5, 0))) == op(data_np, _np.ones((5, 0))) + + @with_seed() @use_np @with_array_function_protocol diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 391a07411b15..bfe6c3d43b50 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3515,7 +3515,7 @@ def dbg(name, data): for config in configs: for optimize in [False, True]: rtol = 1e-2 if dtype == 'float16' else 1e-3 - atol = 1e-4 if dtype == 'float16' else 1e-5 + atol = 1e-4 if dtype == 'float16' else 1e-5 (subscripts, operands, get_grad) = config test_einsum = TestEinsum(subscripts, optimize) if hybridize: @@ -3556,7 +3556,7 
@@ def dbg(name, data): for config in configs: (subscripts, operands) = config rtol = 1e-2 if dtype == 'float16' else 1e-3 - atol = 1e-4 if dtype == 'float16' else 1e-5 + atol = 1e-4 if dtype == 'float16' else 1e-5 grad = [] x_np = [] for shape in operands: @@ -3741,6 +3741,29 @@ def hybrid_forward(self, F, a, *args, **kwargs): assert_almost_equal(npx_out.asnumpy(), expected_out, rtol=1e-3, atol=1e-5) +@with_seed() +@use_np +def test_np_share_memory(): + ops = [np.shares_memory, np.may_share_memory] + # reshape not support boolean types + dtypes = [np.int8, np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64] + for op in ops: + for dt in dtypes: + x = np.zeros([13, 21, 23, 22], dtype=dt) + assert not op(x[0,:,:,:], x[1,:,:,:]) + assert not op(x[2,:,:,:], x[3,:,:,:]) + assert not op(x[2:5,0,0,0], x[3:4,0,0,0]) + assert not op(x[2:5,0,0,0], x[4:7,0,0,0]) + assert op(x[0,0,0,2:5], x[0,0,0,3:4]) + assert op(x[0,6,0,2:5], x[0,6,0,4:7]) + assert not op(x[0,5,0,2:5], x[0,6,0,4:7]) + + for adt in dtypes: + assert not op(x, np.ones((5, 0), dtype=adt)) + assert not op(np.ones((5, 0), dtype=adt), x) + assert not op(np.ones((5, 0), dtype=dt), np.ones((0, 3, 0), dtype=adt)) + + if __name__ == '__main__': import nose nose.runmodule()
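
---

For reference, below is a rough, non-authoritative Python sketch of the interval test that `NumpyShareMemoryCompute` (added above in `src/operator/numpy/np_memory_op.h`) performs. The kernel only compares the byte ranges `[dptr, dptr + Size() * sizeof(T))` of the two inputs, which is why the docstrings note that `max_work` is a dummy argument and that `shares_memory` and `may_share_memory` behave identically in this implementation. The helper name `bounds_overlap` and the use of plain NumPy arrays are illustrative assumptions, not part of the patch.

```python
# Illustrative sketch of the bounds check done by NumpyShareMemoryCompute.
# It only compares the memory intervals [start, start + size_in_bytes) of the
# two buffers, so -- like the new operator -- it answers "might these overlap?"
# rather than "do they share an element?".
import numpy as np


def bounds_overlap(a, b):
    """Return True if the byte ranges backing `a` and `b` overlap."""
    if a.size == 0 or b.size == 0:
        # Empty arrays never share memory (mirrors the early return in the kernel).
        return False
    start1 = a.__array_interface__['data'][0]
    end1 = start1 + a.size * a.itemsize
    start2 = b.__array_interface__['data'][0]
    end2 = start2 + b.size * b.itemsize
    return start1 < end2 and start2 < end1


x = np.zeros((3, 4))
print(bounds_overlap(x[0, :], x[1, :]))    # False: disjoint rows
print(bounds_overlap(x[:, 0], x[:, 1]))    # True: column views have overlapping bounds
print(bounds_overlap(x, np.ones((5, 0))))  # False: empty operand
```

With an MXNet build that includes this patch, the corresponding calls would be `mxnet.numpy.shares_memory(a, b)` and `mxnet.numpy.may_share_memory(a, b)`, which the tests in `test_numpy_op.py` exercise against the same cases shown here.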