From df4125aa1d4ef013e68e6adf3738dabdb1b52865 Mon Sep 17 00:00:00 2001
From: Tao Lv
Date: Tue, 8 Oct 2019 08:39:56 +0800
Subject: [PATCH 1/5] update NEWS.md and README.md (#16385)

---
 NEWS.md   | 24 ++++++++++++++++++++++++
 README.md |  1 +
 2 files changed, 25 insertions(+)

diff --git a/NEWS.md b/NEWS.md
index 8ea31334b6d1..b49c119d87cb 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -18,6 +18,7 @@ MXNet Change Log
 ================
 - [MXNet Change Log](#mxnet-change-log)
+  * [1.5.1](#151)
   * [1.5.0](#150)
     + [New Features](#new-features)
       - [Automatic Mixed Precision(experimental)](#automatic-mixed-precision-experimental-)
@@ -199,6 +200,29 @@ MXNet Change Log
 * [v0.7](#v07)
 * [v0.5 (initial release)](#v05--initial-release-)
 
+## 1.5.1
+Apache MXNet (incubating) 1.5.1 is a maintenance release incorporating important bug fixes and performance improvements. All users of Apache MXNet (incubating) 1.5.0 are advised to upgrade. You can install Apache MXNet (incubating) 1.5.1 at the usual place. Please review these release notes for the full list of bug fixes.
+
+### Bug-fixes
+* add deconv in TRT subgraph (#15666) (#16043)
+* Update TRT tutorial with new APIs (#16044)
+* Fix _copy_to on MKLDNN backend (#15637) (#15803)
+* Benchmark doc fix (#15769) (#16029)
+* remove Julia cat image for license issue (#15964) (#16026)
+* added check for empty params file and unknown param (not arg/aux) (#15917)
+* fix license issues (#15806) (#15860)
+* prevent TRT_Logger to be destroyed before TRT engine (#14898) (#15877)
+* [MXNET-1086] added sub and mul to ONNX->TensorRT conversion (#15344) (#15875)
+* handle fix_gamma in tensorrt subgraph conversion correctly (#15645) (#15874)
+* fix LinearRegressionOutput with empty label (#15620) (#15873)
+* [v1.5.x] [MKLDNN] Independent gradients requests check with respect to weights… (#15805)
+* fix dropout mask output (#15697) (#15804)
+* fix fp32 flatten issue (#15351) (#15802)
+* Clojure package remove source images (#15828)
+* changed constructor args (#15601) (#15827)
+* Add MKLDNN 4c layout to fix gluoncv se_resnext101_64x4d (#15692) (#15801)
+* Fix the bug of `MXEnginePushAsyncND` and `MXEnginePushSyncND` (#15751) (#15792)
+
 ## 1.5.0
 
 ### New Features
diff --git a/README.md b/README.md
index 413f8c492c45..e8992393f246 100644
--- a/README.md
+++ b/README.md
@@ -50,6 +50,7 @@ How to Contribute
 What's New
 ----------
+* [Version 1.5.1 Release](https://github.com/apache/incubator-mxnet/releases/tag/1.5.1) - MXNet 1.5.1 Patch Release.
 * [Version 1.5.0 Release](https://github.com/apache/incubator-mxnet/releases/tag/1.5.0) - MXNet 1.5.0 Release.
 * [Version 1.4.1 Release](https://github.com/apache/incubator-mxnet/releases/tag/1.4.1) - MXNet 1.4.1 Patch Release.
 * [Version 1.4.0 Release](https://github.com/apache/incubator-mxnet/releases/tag/1.4.0) - MXNet 1.4.0 Release.
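
For reference, assuming the wheels for this release are published under the usual package name on PyPI, the upgrade the notes refer to would be:

    pip install --upgrade mxnet==1.5.1
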
From 0bace559cd6d70dbbfe3ca6545e56c98c92116b1 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Tue, 8 Oct 2019 02:34:37 +0000 Subject: [PATCH 2/5] fix choice signature --- python/mxnet/ndarray/numpy/random.py | 4 +--- python/mxnet/numpy/random.py | 4 ++-- python/mxnet/symbol/numpy/random.py | 5 +---- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 3ad9f56eca48..9e401695709d 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -245,7 +245,7 @@ def multinomial(n, pvals, size=None): return _npi.multinomial(n=n, pvals=pvals, size=size) -def choice(a, size=None, replace=True, p=None, **kwargs): +def choice(a, size=None, replace=True, p=None, ctx=None, out=None): """Generates a random sample from a given 1-D array Parameters @@ -298,10 +298,8 @@ def choice(a, size=None, replace=True, p=None, **kwargs): array([2, 3, 0]) """ from ...numpy import ndarray as np_ndarray - ctx = kwargs.pop('ctx', None) if ctx is None: ctx = current_context() - out = kwargs.pop('out', None) if size == (): size = None if isinstance(a, np_ndarray): diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index 6d845b617261..746ce99c6d9c 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -178,7 +178,7 @@ def multinomial(n, pvals, size=None, **kwargs): return _mx_nd_np.random.multinomial(n, pvals, size, **kwargs) -def choice(a, size=None, replace=True, p=None, **kwargs): +def choice(a, size=None, replace=True, p=None, ctx=None, out=None): """Generates a random sample from a given 1-D array Parameters @@ -230,4 +230,4 @@ def choice(a, size=None, replace=True, p=None, **kwargs): >>> np.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0]) array([2, 3, 0]) """ - return _mx_nd_np.random.choice(a, size, replace, p, **kwargs) + return _mx_nd_np.random.choice(a, size, replace, p, ctx, out) diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index 2398b9ce759b..84cc5704b138 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -190,7 +190,7 @@ def normal(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None): ctx=ctx, dtype=dtype, out=out) -def choice(a, size=None, replace=True, p=None, **kwargs): +def choice(a, size=None, replace=True, p=None, ctx=None, out=None): """Generates a random sample from a given 1-D array Parameters @@ -243,13 +243,10 @@ def choice(a, size=None, replace=True, p=None, **kwargs): array([2, 3, 0]) """ from ._symbol import _Symbol as np_symbol - ctx = kwargs.pop('ctx', None) if ctx is None: ctx = current_context() - out = kwargs.pop('out', None) if size == (): size = None - if isinstance(a, np_symbol): ctx = None if p is None: From ec766d55b26fe734d41fd6b9e6b81c404d77a947 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Tue, 8 Oct 2019 05:31:54 +0000 Subject: [PATCH 3/5] add raise test for shape --- tests/python/unittest/test_exc_handling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_exc_handling.py b/tests/python/unittest/test_exc_handling.py index 0fa5fcb6f1e6..e3c333705260 100644 --- a/tests/python/unittest/test_exc_handling.py +++ b/tests/python/unittest/test_exc_handling.py @@ -204,7 +204,7 @@ def test_np_reshape_exception(): @with_seed() @use_np def test_np_random_incorrect_named_arguments(): - random_ops = ['uniform', 'normal', 'randint'] + random_ops = ['uniform', 'normal', 'randint', 
'choice']
     for op_name in random_ops:
         op = getattr(mx.np.random, op_name, None)
         assert op is not None

From d5666eda70804931592fbb54ac5c526ab4587357 Mon Sep 17 00:00:00 2001
From: igolan <26796766+igolan@users.noreply.github.com>
Date: Tue, 8 Oct 2019 15:41:14 -0400
Subject: [PATCH 4/5] Round and sign straight-through-estimator C operators.
 (#16373)

* Implemented round and sign straight-through-estimator C operators.

* fixed lint
---
 src/operator/contrib/stes_op.cc               |  84 +++++++++++
 src/operator/contrib/stes_op.cu               |  43 ++++++
 src/operator/contrib/stes_op.h                |  33 +++++
 tests/python/unittest/test_contrib_stes_op.py | 137 ++++++++++++++++++
 4 files changed, 297 insertions(+)
 create mode 100644 src/operator/contrib/stes_op.cc
 create mode 100644 src/operator/contrib/stes_op.cu
 create mode 100644 src/operator/contrib/stes_op.h
 create mode 100644 tests/python/unittest/test_contrib_stes_op.py

diff --git a/src/operator/contrib/stes_op.cc b/src/operator/contrib/stes_op.cc
new file mode 100644
index 000000000000..c334d4d1b59c
--- /dev/null
+++ b/src/operator/contrib/stes_op.cc
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+ /*!
+ *  Copyright (c) 2019 by Contributors
+ * \file stes_op.cc
+ * \brief Straight-through-estimator round and sign operators.
+ * \author Itay Golan
+ */
+
+#include "stes_op.h"
+
+
+namespace mxnet {
+namespace op {
+
+// Round STE
+MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(_contrib_round_ste, cpu, mshadow_op::round)
+.describe(R"code(Straight-through-estimator of `round()`.
+
+In forward pass, returns element-wise rounded value to the nearest integer of the input (same as `round()`).
+
+In backward pass, returns gradients of ``1`` everywhere (instead of ``0`` everywhere as in `round()`):
+:math:`\frac{d}{dx}{round\_ste(x)} = 1` vs. :math:`\frac{d}{dx}{round(x)} = 0`.
+This is useful for quantized training.
+
+Reference: Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation.
+
+Example::
+  x = round_ste([-1.5, 1.5, -1.9, 1.9, 2.7])
+  x.backward()
+  x = [-2., 2., -2., 2., 3.]
+  x.grad() = [1., 1., 1., 1., 1.]
+
+The storage type of ``round_ste`` output depends upon the input storage type:
+  - round_ste(default) = default
+  - round_ste(row_sparse) = row_sparse
+  - round_ste(csr) = csr
+)code" ADD_FILELINE)
+.set_attr<nnvm::FGradient>("FGradient", CloneGradient{"_backward_round_ste"});
+
+// Sign STE
+MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(_contrib_sign_ste, cpu, mshadow_op::sign)
+.describe(R"code(Straight-through-estimator of `sign()`.
+
+In forward pass, returns element-wise sign of the input (same as `sign()`).
+
+In backward pass, returns gradients of ``1`` everywhere (instead of ``0`` everywhere as in ``sign()``):
+:math:`\frac{d}{dx}{sign\_ste(x)} = 1` vs. :math:`\frac{d}{dx}{sign(x)} = 0`.
+This is useful for quantized training.
+
+Reference: Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation.
+
+Example::
+  x = sign_ste([-2, 0, 3])
+  x.backward()
+  x = [-1., 0., 1.]
+  x.grad() = [1., 1., 1.]
+
+The storage type of ``sign_ste`` output depends upon the input storage type:
+  - sign_ste(default) = default
+  - sign_ste(row_sparse) = row_sparse
+  - sign_ste(csr) = csr
+)code" ADD_FILELINE)
+.set_attr<nnvm::FGradient>("FGradient", CloneGradient{"_backward_sign_ste"});
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/contrib/stes_op.cu b/src/operator/contrib/stes_op.cu
new file mode 100644
index 000000000000..85e3ddaf206f
--- /dev/null
+++ b/src/operator/contrib/stes_op.cu
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+ /*!
+ *  Copyright (c) 2019 by Contributors
+ * \file stes_op.cu
+ * \brief Straight-through-estimator round and sign operators.
+ * \author Itay Golan
+ */
+
+#include "stes_op.h"
+
+namespace mxnet {
+namespace op {
+
+// Round STE
+NNVM_REGISTER_OP(_contrib_round_ste)
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::round>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::ComputeEx<gpu, mshadow_op::round>);
+
+// Sign STE
+NNVM_REGISTER_OP(_contrib_sign_ste)
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::sign>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::ComputeEx<gpu, mshadow_op::sign>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/contrib/stes_op.h b/src/operator/contrib/stes_op.h
new file mode 100644
index 000000000000..7185fbf1d6e1
--- /dev/null
+++ b/src/operator/contrib/stes_op.h
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+ /*!
+ *  Copyright (c) 2019 by Contributors
+ * \file stes_op.h
+ * \brief Straight-through-estimator round and sign operators.
+ * \author Itay Golan + */ + +#ifndef MXNET_OPERATOR_CONTRIB_STES_OP_H_ +#define MXNET_OPERATOR_CONTRIB_STES_OP_H_ + +#include +#include "../tensor/elemwise_unary_op.h" + +#endif // MXNET_OPERATOR_CONTRIB_STES_OP_H_ diff --git a/tests/python/unittest/test_contrib_stes_op.py b/tests/python/unittest/test_contrib_stes_op.py new file mode 100644 index 000000000000..5864ec9db5b1 --- /dev/null +++ b/tests/python/unittest/test_contrib_stes_op.py @@ -0,0 +1,137 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from common import with_seed +import mxnet as mx +from mxnet import nd, autograd, gluon +from mxnet.test_utils import default_context + + +class RoundSTENET(gluon.HybridBlock): + def __init__(self, w_init, **kwargs): + super(RoundSTENET, self).__init__(**kwargs) + with self.name_scope(): + self.w = self.params.get('w', shape=30, init=mx.initializer.Constant(w_init), grad_req='write') + + @staticmethod + def expected_grads(in_data, w_init): + return (in_data * w_init).round() + (in_data * w_init) + + @staticmethod + def expected_output(in_data, w_init): + return (in_data * w_init).round() * w_init + + def hybrid_forward(self, F, x, w): + # Simple forward function: round_ste(w*x)*w + out = w * x + out = F.contrib.round_ste(out) + # Uncomment to see how test fails with round + # out = F.round(out) + out = out * w + return out + + +class SignSTENET(gluon.HybridBlock): + def __init__(self, w_init, **kwargs): + super(SignSTENET, self).__init__(**kwargs) + with self.name_scope(): + self.w = self.params.get('w', shape=30, init=mx.initializer.Constant(w_init), grad_req='write') + + @staticmethod + def expected_grads(in_data, w_init): + return (in_data * w_init).sign() + (in_data * w_init) + + @staticmethod + def expected_output(in_data, w_init): + return (in_data * w_init).sign() * w_init + + def hybrid_forward(self, F, x, w): + # Simple forward function: sign_ste(w*x)*w + out = w * x + out = F.contrib.sign_ste(out) + # Uncomment to see how test fails with sign + # out = F.sign(out) + out = out * w + return out + + +def check_ste(net_type_str, w_init, hybridize, in_data, ctx=None): + ctx = ctx or default_context() + + net = eval(net_type_str)(w_init=w_init) + if hybridize: + net.hybridize() + # Init + net.collect_params().initialize(mx.init.Constant([w_init]), ctx=ctx) + + # Test: + in_data = in_data.as_in_context(ctx) + with mx.autograd.record(): + out = net(in_data) + assert all(out == net.expected_output(in_data, w_init)), net_type_str + " output is " + str(out) + ", but" + \ + " expected " + str(net.expected_output(in_data, w_init)) + + out.backward() + assert all(net.w.grad() == net.expected_grads(in_data, w_init)), net_type_str + " w grads are " + \ + str(net.w.grad()) + " but expected " + \ + str(net.expected_grads(in_data, w_init)) + with 
mx.autograd.record(): + out = net(in_data) + assert all(out == net.expected_output(in_data, w_init)), net_type_str + " output is " + str(out) + ", but" + \ + " expected " + str(net.expected_output(in_data, w_init)) + out.backward() + assert all(net.w.grad() == net.expected_grads(in_data, w_init)), net_type_str + " w grads are " + \ + str(net.w.grad()) + " but expected " + \ + str(net.expected_grads(in_data, w_init)) + +@with_seed() +def test_contrib_round_ste(): + # Test with random data + in_data = nd.uniform(-10, 10, shape=30) # 10 and 30 are arbitrary numbers + w_init = float(nd.uniform(-10, 10, shape=1).asscalar()) + check_ste(net_type_str="RoundSTENET", w_init=w_init, hybridize=True, in_data=in_data) + check_ste(net_type_str="RoundSTENET", w_init=w_init, hybridize=False, in_data=in_data) + + # Test 1.5 (verifies that .5 rounds the same as in round) + in_data = nd.array([1.5]*30) # 10 and 30 are arbitrary numbers + w_init = 1. + check_ste(net_type_str="RoundSTENET", w_init=w_init, hybridize=True, in_data=in_data) + check_ste(net_type_str="RoundSTENET", w_init=w_init, hybridize=False, in_data=in_data) + + # Test 0 + in_data = nd.array([0]*30) # 10 and 30 are arbitrary numbers + w_init = 0. + check_ste(net_type_str="RoundSTENET", w_init=w_init, hybridize=True, in_data=in_data) + check_ste(net_type_str="RoundSTENET", w_init=w_init, hybridize=False, in_data=in_data) + + +@with_seed() +def test_contrib_sign_ste(): + in_data = nd.uniform(-10, 10, shape=30) # 10 and 30 are arbitrary numbers + w_init = float(nd.uniform(-10, 10, shape=1).asscalar()) + check_ste(net_type_str="SignSTENET", w_init=w_init, hybridize=True, in_data=in_data) + check_ste(net_type_str="SignSTENET", w_init=w_init, hybridize=False, in_data=in_data) + + # Test 0 + in_data = nd.array([0]*30) # 10 and 30 are arbitrary numbers + w_init = 0. 
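+    # With w_init = 0. every product w * x is 0, so sign_ste outputs all zeros and
+    # expected_grads = sign(w*x) + w*x is all zeros as well; the STE part of the
+    # chain rule still contributes d(sign_ste)/d(input) = 1 rather than 0.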
+ check_ste(net_type_str="SignSTENET", w_init=w_init, hybridize=True, in_data=in_data) + check_ste(net_type_str="SignSTENET", w_init=w_init, hybridize=False, in_data=in_data) + +if __name__ == '__main__': + import nose + nose.runmodule() \ No newline at end of file From 15ea40d9bf7da3c4618ca45a6d023d9b0fb1c295 Mon Sep 17 00:00:00 2001 From: reminisce Date: Tue, 8 Oct 2019 16:23:51 -0700 Subject: [PATCH 5/5] Add boolean ndarray (#15940) * Initial infra of boolean ndarray Add np.equal implemented using tvmop Fix setting DLDataType conversion for boolean ndarray Add equal_gpu Fix inputs with different ndims Fix copying boolean ndarrays across devices Refactor binary logic op impl by tvm Add more logic ops Refactor TVMOpModule::Call to CallEx Add binary scalar logic op expr and schedule Add binary scalar logic ops Add free functions for logic ops Rebase with master to fix SetDLTensor bug Fix pylint Add sum op for boolean ndarrays using tvm op module Add sum boolean gpu compute Add bool type support to boolean_mask Boolean indexing working Clean up Fix merge Sync Makefile Rebase Add boolean indexing test Fix sanity Fix gpu and add autograd test Rebase Fix test for windows Fix tests Try to fix cuda arch missing error in ci Fix ci Fix windows build Try to fix cmake Fix cmake Fix Revert config.mk * Fix cmake * Skip compute capability <= 52 for TVM generated ops * Fix sanity --- 3rdparty/mshadow/mshadow/base.h | 62 +++- CMakeLists.txt | 6 +- Makefile | 6 +- ci/docker/runtime_functions.sh | 4 +- contrib/tvmop/__init__.py | 1 + contrib/tvmop/compile.py | 41 ++- contrib/tvmop/core/__init__.py | 18 + contrib/tvmop/core/fromnumeric.py | 63 ++++ contrib/tvmop/core/umath.py | 122 +++++++ contrib/tvmop/opdef.py | 6 +- contrib/tvmop/utils.py | 16 +- include/mxnet/tensor_blob.h | 1 + python/mxnet/ndarray/ndarray.py | 6 + python/mxnet/ndarray/numpy/_op.py | 198 ++++++++++- python/mxnet/numpy/multiarray.py | 316 ++++++++++++++---- python/mxnet/numpy/utils.py | 7 +- python/mxnet/symbol/numpy/_symbol.py | 246 +++++++++++--- python/mxnet/test_utils.py | 17 + src/ndarray/ndarray.cc | 2 +- src/ndarray/ndarray_function.cc | 9 + src/ndarray/ndarray_function.cu | 10 +- src/operator/contrib/boolean_mask.cc | 7 +- src/operator/contrib/boolean_mask.cu | 4 +- src/operator/mxnet_op.h | 16 + src/operator/numpy/np_broadcast_reduce_op.h | 21 +- .../numpy/np_broadcast_reduce_op_value.cc | 71 ++++ .../numpy/np_elemwise_broadcast_op.cc | 253 +++++++++++++- src/operator/operator_tune.cc | 23 +- .../elemwise_binary_broadcast_op_logic.cc | 6 - .../tensor/elemwise_binary_scalar_op_logic.cc | 6 - src/operator/tensor/elemwise_unary_op.h | 2 +- src/operator/tensor/init_op.h | 10 +- src/operator/tvmop/op_module.cc | 27 +- src/operator/tvmop/op_module.h | 18 +- tests/python/unittest/test_numpy_gluon.py | 12 +- tests/python/unittest/test_numpy_ndarray.py | 149 ++++++++- tests/python/unittest/test_numpy_op.py | 35 +- 37 files changed, 1620 insertions(+), 197 deletions(-) create mode 100644 contrib/tvmop/core/__init__.py create mode 100644 contrib/tvmop/core/fromnumeric.py create mode 100644 contrib/tvmop/core/umath.py diff --git a/3rdparty/mshadow/mshadow/base.h b/3rdparty/mshadow/mshadow/base.h index 1076f122c3ae..e0e9602c00db 100755 --- a/3rdparty/mshadow/mshadow/base.h +++ b/3rdparty/mshadow/mshadow/base.h @@ -311,6 +311,7 @@ enum TypeFlag { kInt32 = 4, kInt8 = 5, kInt64 = 6, + kBool = 7, }; template @@ -411,6 +412,11 @@ struct DataType { static const int kFlag = kInt64; static const int kLanes = 1; }; +template<> +struct DataType { + 
static const int kFlag = kBool; + static const int kLanes = 1; +}; /*! \brief type enum value for default real type */ const int default_type_flag = DataType::kFlag; @@ -1138,10 +1144,64 @@ struct minimum { LOG(FATAL) << "Unknown type enum " << type; \ } +#define MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType, ...) \ + switch (type) { \ + case mshadow::kFloat32: \ + { \ + typedef float DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat64: \ + { \ + typedef double DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat16: \ + { \ + typedef mshadow::half::half_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kUint8: \ + { \ + typedef uint8_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt8: \ + { \ + typedef int8_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt32: \ + { \ + typedef int32_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt64: \ + { \ + typedef int64_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kBool: \ + { \ + typedef bool DType; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + LOG(FATAL) << "Unknown type enum " << type; \ + } + /*! \brief get data type size from type enum */ inline size_t mshadow_sizeof(int type) { int size = 0; - MSHADOW_TYPE_SWITCH(type, DType, size = sizeof(DType);); + MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType, size = sizeof(DType);); return size; } diff --git a/CMakeLists.txt b/CMakeLists.txt index f441e9b0bd3b..5045bba9d989 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -756,11 +756,15 @@ if(USE_TVM_OP) endif() endif() + set(TVM_OP_COMPILE_OPTIONS "-o${CMAKE_CURRENT_BINARY_DIR}/libtvmop.so") + if(CUDA_ARCH_BIN) + set(TVM_OP_COMPILE_OPTIONS "${TVM_OP_COMPILE_OPTIONS}" "--cuda-arch" "${CUDA_ARCH_BIN}") + endif() add_custom_command(TARGET mxnet POST_BUILD COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH="${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/python:${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/topi/python:${CMAKE_CURRENT_SOURCE_DIR}/contrib" LD_LIBRARY_PATH=${CMAKE_CURRENT_BINARY_DIR}:${CMAKE_CURRENT_BINARY_DIR}/3rdparty/tvm:$ENV{LD_LIBRARY_PATH} - ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/contrib/tvmop/compile.py -o${CMAKE_CURRENT_BINARY_DIR}/libtvmop.so + ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/contrib/tvmop/compile.py ${TVM_OP_COMPILE_OPTIONS} ) endif() diff --git a/Makefile b/Makefile index b3b188a2cf3b..0a1e355ee5e8 100644 --- a/Makefile +++ b/Makefile @@ -630,11 +630,15 @@ lib/libtvm_runtime.so: ls $(ROOTDIR)/lib; \ cd $(ROOTDIR) +TVM_OP_COMPILE_OPTIONS = -o $(ROOTDIR)/lib/libtvmop.so +ifneq ($(CUDA_ARCH),) + TVM_OP_COMPILE_OPTIONS += --cuda-arch "$(CUDA_ARCH)" +endif lib/libtvmop.so: lib/libtvm_runtime.so $(wildcard contrib/tvmop/*/*.py contrib/tvmop/*.py) echo "Compile TVM operators" PYTHONPATH=$(TVM_PATH)/python:$(TVM_PATH)/topi/python:$(ROOTDIR)/contrib \ LD_LIBRARY_PATH=$(ROOTDIR)/lib \ - python3 $(ROOTDIR)/contrib/tvmop/compile.py -o $(ROOTDIR)/lib/libtvmop.so + python3 $(ROOTDIR)/contrib/tvmop/compile.py $(TVM_OP_COMPILE_OPTIONS) NNVM_INC = $(wildcard $(NNVM_PATH)/include/*/*.h) NNVM_SRC = $(wildcard $(NNVM_PATH)/src/*/*/*.cc $(NNVM_PATH)/src/*/*.cc $(NNVM_PATH)/src/*.cc) diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index b2a50f30af4e..7d91d4e5d121 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -228,7 +228,7 @@ build_ubuntu_gpu_mkldnn_release() { # $1 -> mxnet_variant: the mxnet variant to build, e.g. cpu, cu100, cu92mkl, etc. 
build_dynamic_libmxnet() { set -ex - + local mxnet_variant=${1:?"This function requires a mxnet variant as the first argument"} # relevant licenses will be placed in the licenses directory @@ -948,7 +948,7 @@ cd_unittest_ubuntu() { fi $nose_cmd $NOSE_TIMER_ARGUMENTS --verbose tests/python/unittest - $nose_cmd $NOSE_TIMER_ARGUMENTS --verbose tests/python/quantization + $nose_cmd $NOSE_TIMER_ARGUMENTS --verbose tests/python/quantization # https://github.com/apache/incubator-mxnet/issues/11801 # if [[ ${mxnet_variant} = "cpu" ]] || [[ ${mxnet_variant} = "mkl" ]]; then diff --git a/contrib/tvmop/__init__.py b/contrib/tvmop/__init__.py index 1234ee7d31f1..db41574ee867 100644 --- a/contrib/tvmop/__init__.py +++ b/contrib/tvmop/__init__.py @@ -21,3 +21,4 @@ from .utils import assign_by_req, reduce_axes from . import basic +from . import core diff --git a/contrib/tvmop/compile.py b/contrib/tvmop/compile.py index e6af0a276560..3c0efdd6b806 100644 --- a/contrib/tvmop/compile.py +++ b/contrib/tvmop/compile.py @@ -21,7 +21,13 @@ import os import argparse +import re +import logging from tvmop.opdef import __OP_DEF__ +from tvm.autotvm.measure.measure_methods import set_cuda_target_arch + +logging.basicConfig(level=logging.INFO) + def get_target(device): if device == "cpu": @@ -31,12 +37,39 @@ def get_target(device): assert False, "Unknown device " + device +def get_cuda_arch(arch): + if arch is None: + return None + + if not isinstance(arch, str): + raise TypeError('Expecting parameter arch as a str, while got a {}'.format(str(type(arch)))) + + if len(arch) == 0: + return None + + # the arch string contains '-arch=sm_xx' + flags = arch.split() + for flag in flags: + if flag.startswith('-arch='): + return flag[len('-arch='):] + + # find the highest compute capability + comp_caps = re.findall(r'\d+', arch) + if len(comp_caps) == 0: + return None + + comp_caps = [int(c) for c in comp_caps] + return 'sm_' + str(max(comp_caps)) + + if __name__ == "__main__": import sys sys.path.append(os.path.dirname(sys.path[0])) parser = argparse.ArgumentParser(description="Generate tvm operators") parser.add_argument("-o", action="store", required=True, dest="target_path", help="Target path which stores compiled library") + parser.add_argument('--cuda-arch', type=str, default=None, dest='cuda_arch', + help='The cuda arch for compiling kernels for') arguments = parser.parse_args() func_list_llvm = [] @@ -52,8 +85,14 @@ def get_target(device): binds=operator_def.get_binds(args)) func_list.append(func_lower) - lowered_funcs = {get_target("cpu") : func_list_llvm} + lowered_funcs = {get_target("cpu"): func_list_llvm} if len(func_list_cuda) > 0: lowered_funcs[get_target("cuda")] = func_list_cuda + cuda_arch = get_cuda_arch(arguments.cuda_arch) + if cuda_arch is None: + logging.info('No cuda arch specified. TVM will try to detect it from the build platform.') + else: + logging.info('Cuda arch {} set for compiling TVM operator kernels.'.format(cuda_arch)) + set_cuda_target_arch(cuda_arch) func_binary = tvm.build(lowered_funcs, name="tvmop") func_binary.export_library(arguments.target_path) diff --git a/contrib/tvmop/core/__init__.py b/contrib/tvmop/core/__init__.py new file mode 100644 index 000000000000..841d4ad9db27 --- /dev/null +++ b/contrib/tvmop/core/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from . import umath, fromnumeric diff --git a/contrib/tvmop/core/fromnumeric.py b/contrib/tvmop/core/fromnumeric.py new file mode 100644 index 000000000000..e6c4c2be0814 --- /dev/null +++ b/contrib/tvmop/core/fromnumeric.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import tvm +from .. import defop +from ..utils import reduce_axes, assign_by_req + + +def _compute_sum(itype, otype, ndim, reduce1st_dim, req): + axes = ([reduce1st_dim, 1 - reduce1st_dim] * ndim)[:ndim] + a = tvm.placeholder([tvm.var() for _ in range(ndim)], name='a', dtype=itype) + reduce_output = reduce_axes(a, axes, tvm.sum, otype) + output_placeholder, final_output = assign_by_req(reduce_output, req) + s = tvm.create_schedule(final_output.op) + return s, a, output_placeholder, final_output, [reduce_output, final_output] + + +@defop(name='sum_cpu', target='cpu', itype=['bool'], + otype=['float32', 'float64', 'int32', 'int64'], + ndim=[5], req=['kWriteTo', 'kAddTo'], reduce1st_dim=[0, 1], + attrs=["reduce1st_dim", "req"]) +def _sum_cpu(itype, otype, ndim, reduce1st_dim, req): + s, a, output_placeholder, final_output, tensor_list = _compute_sum( + itype, otype, ndim, reduce1st_dim, req) + for t in tensor_list: + axes = [axis for axis in t.op.axis] + fused = s[t].fuse(*axes) + s[t].parallel(fused) + return s, [a, output_placeholder, final_output] + + +@defop(name='sum_gpu', target='gpu', itype=['bool'], + otype=['float32', 'float64', 'int32', 'int64'], + ndim=[5], req=['kWriteTo', 'kAddTo'], reduce1st_dim=[0, 1], + attrs=["reduce1st_dim", "req"]) +def _sum_gpu(itype, otype, ndim, reduce1st_dim, req): + s, a, output_placeholder, final_output, tensor_list = _compute_sum( + itype, otype, ndim, reduce1st_dim, req) + num_threads = 64 + for t in tensor_list: + block_x = tvm.thread_axis("blockIdx.x") + thread_x = tvm.thread_axis("threadIdx.x") + axes = [axis for axis in t.op.axis] + fused = s[t].fuse(*axes) + bx, tx = s[t].split(fused, factor=num_threads) + s[t].bind(bx, block_x) + s[t].bind(tx, thread_x) + return s, [a, output_placeholder, final_output] diff --git a/contrib/tvmop/core/umath.py b/contrib/tvmop/core/umath.py new file mode 100644 index 000000000000..ad099299aae5 --- /dev/null +++ 
b/contrib/tvmop/core/umath.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from .. import defop, AllTypes + +_bin_logic_op_map = { + 'equal': lambda a, b, *idx: a[idx] == b[idx], + 'not_equal': lambda a, b, *idx: a[idx] != b[idx], + 'greater': lambda a, b, *idx: a[idx] > b[idx], + 'less': lambda a, b, *idx: a[idx] < b[idx], + 'greater_equal': lambda a, b, *idx: a[idx] >= b[idx], + 'less_equal': lambda a, b, *idx: a[idx] <= b[idx], +} + + +def _compute_binary_logic(op, dtype, ndim): + a = tvm.placeholder([tvm.var() for _ in range(ndim)], dtype=dtype, name='a') + b = tvm.placeholder([tvm.var() for _ in range(ndim)], dtype=dtype, name='b') + c = tvm.compute([tvm.var() for _ in range(ndim)], + lambda *idx: _bin_logic_op_map[op](a, b, *idx), name='c') + s = tvm.create_schedule(c.op) + return s, a, b, c + + +_bin_logic_cpu_attrs = { + 'compute_func': _compute_binary_logic, + 'target': 'cpu', + 'auto_broadcast': True, + 'itype': AllTypes + ['bool'], + 'ndim': list(range(6)) +} + +_bin_logic_gpu_attrs = { + 'compute_func': _compute_binary_logic, + 'target': 'gpu', + 'auto_broadcast': True, + 'itype': AllTypes + ['bool'], + 'ndim': list(range(6)) +} + + +def _binary_logic_cpu(compute_func, op, itype, ndim): + s, a, b, c = compute_func(op, itype, ndim) + axes = [axis for axis in c.op.axis] + fused = s[c].fuse(*axes) + s[c].parallel(fused) + return s, [a, b, c] + + +def _binary_logic_gpu(compute_func, op, itype, ndim): + s, a, b, c = compute_func(op, itype, ndim) + axes = [axis for axis in c.op.axis] + fused = s[c].fuse(*axes) + bx, tx = s[c].split(fused, factor=64) + s[c].bind(bx, tvm.thread_axis('blockIdx.x')) + s[c].bind(tx, tvm.thread_axis('threadIdx.x')) + return s, [a, b, c] + + +# register binary element-wise logic ops with broadcasting supported +for op_name in _bin_logic_op_map.keys(): + defop(name='{}_cpu'.format(op_name), op=op_name, **_bin_logic_cpu_attrs)(_binary_logic_cpu) + defop(name='{}_gpu'.format(op_name), op=op_name, **_bin_logic_gpu_attrs)(_binary_logic_gpu) + + +# Note that `b.dtype` is hard-coded as 'float64'. +# We should always promote `a`'s elements to `b.dtype`. 
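+# For example, `equal_scalar` on an int64 tensor `a` with scalar `b = 0.5` is
+# computed as `a[idx].astype('float64') == 0.5`, so the comparison runs in
+# float64 instead of truncating the scalar down to int64.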
+_bin_scalar_logic_op_map = { + 'equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) == b, + 'not_equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) != b, + 'greater_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) > b, + 'less_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) < b, + 'greater_equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) >= b, + 'less_equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) <= b, +} + + +def _compute_binary_scalar_logic(op, dtype, ndim): + a = tvm.placeholder([tvm.var() for _ in range(ndim)], name='a', dtype=dtype) + b = tvm.var('b', dtype='float64') + c = tvm.compute([tvm.var() for _ in range(ndim)], + lambda *idx: _bin_scalar_logic_op_map[op](a, b, *idx), name='c') + s = tvm.create_schedule(c.op) + return s, a, b, c + + +_bin_scalar_logic_cpu_attrs = { + 'compute_func': _compute_binary_scalar_logic, + 'target': 'cpu', + 'itype': AllTypes + ['bool'], + 'ndim': list(range(6)) +} + +_bin_scalar_logic_gpu_attrs = { + 'compute_func': _compute_binary_scalar_logic, + 'target': 'gpu', + 'itype': AllTypes + ['bool'], + 'ndim': list(range(6)) +} + + +# register binary element-wise scalar logic ops +for op_name in _bin_scalar_logic_op_map.keys(): + defop(name='{}_cpu'.format(op_name), op=op_name, + **_bin_scalar_logic_cpu_attrs)(_binary_logic_cpu) + defop(name='{}_gpu'.format(op_name), op=op_name, + **_bin_scalar_logic_gpu_attrs)(_binary_logic_gpu) diff --git a/contrib/tvmop/opdef.py b/contrib/tvmop/opdef.py index 32d1832d13dd..39c42f4dd465 100644 --- a/contrib/tvmop/opdef.py +++ b/contrib/tvmop/opdef.py @@ -74,11 +74,12 @@ def __call__(self, *args, **kwargs): def invoke_all(self): for each_kwargs in self.arg_combination: - if (self.attrs_valid(**each_kwargs)): + if self.attrs_valid(**each_kwargs): sch, args = self.func(**each_kwargs) name = self.name \ + ''.join(["{}_{}".format(key, each_kwargs[key]) for key in self.attrs]) \ - + ''.join(["%s_%d" % (arg.dtype, len(arg.shape)) for arg in args]) + + ''.join(["%s_%d" % (arg.dtype, len(arg.shape)) + for arg in args if hasattr(arg, 'shape')]) yield sch, args, name def get_binds(self, args): @@ -107,6 +108,7 @@ def defop(name, target=None, auto_broadcast=False, **kwargs): """ assert name is not None and len(name) > 0 target = "cpu" if target is None else target + def _defop(func): opdef = OpDef(func, name, target, auto_broadcast, **kwargs) __OP_DEF__.append(opdef) diff --git a/contrib/tvmop/utils.py b/contrib/tvmop/utils.py index 329dce2148d9..39d7a8092005 100644 --- a/contrib/tvmop/utils.py +++ b/contrib/tvmop/utils.py @@ -21,16 +21,18 @@ AllTypes = ["float32", "float64", "float16", "uint8", "int8", "int32", "int64"] RealTypes = ["float32", "float64", "float16"] -def assign_by_req(a, req): + +def assign_by_req(a, req, otype=None): b = tvm.placeholder(a.shape, name='assign_by_req_b', dtype=a.dtype) - if (req == "kAddTo"): - c = tvm.compute(a.shape, lambda *idx: a[idx] + b[idx]) + if req == "kAddTo": + c = tvm.compute(a.shape, lambda *idx: a[idx].astype(otype) + b[idx] + if otype else a[idx] + b[idx]) else: - c = tvm.compute(a.shape, lambda *idx: a[idx]) + c = tvm.compute(a.shape, lambda *idx: a[idx].astype(otype) if otype else a[idx]) return b, c -def reduce_axes(X, axes, reducer): +def reduce_axes(X, axes, reducer, atype=None): def get_index(idx, ridx): j = 0 k = 0 @@ -45,5 +47,7 @@ def get_index(idx, ridx): odim = (len(ishape) + 1 - axes[0]) // 2 oshape = [tvm.var() for _ in range(odim)] ridx = [tvm.reduce_axis((0, ishape[i])) for (i, val) in enumerate(axes) if val == 1] - ret = 
tvm.compute(oshape, lambda *idx: reducer(X[get_index(idx, ridx)], axis=ridx), name='ret') + ret = tvm.compute(oshape, lambda *idx: reducer(X[get_index(idx, ridx)].astype(atype) + if atype else X[get_index(idx, ridx)], + axis=ridx), name='ret') return ret diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h index 70af75424252..f64703092121 100755 --- a/include/mxnet/tensor_blob.h +++ b/include/mxnet/tensor_blob.h @@ -382,6 +382,7 @@ class TBlob { case mshadow::kInt32: return DLDataType{kDLInt, 32, 1}; case mshadow::kInt8: return DLDataType{kDLInt, 8, 1}; case mshadow::kInt64: return DLDataType{kDLInt, 64, 1}; + case mshadow::kBool: return DLDataType{kDLUInt, 1, 1}; default: { LOG(FATAL) << "Unknown type_flag=" << type_flag; return DLDataType(); diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 0b72865dc17a..4e3c7efa7be3 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -69,6 +69,7 @@ np.int32: 4, np.int8: 5, np.int64: 6, + np.bool_: 7, } _DTYPE_MX_TO_NP = { @@ -80,6 +81,7 @@ 4: np.int32, 5: np.int8, 6: np.int64, + 7: np.bool_, } _STORAGE_TYPE_STR_TO_ID = { @@ -2995,6 +2997,10 @@ def get_indexing_dispatch_code(key): for idx in key: if isinstance(idx, (NDArray, np.ndarray, list, tuple)): + if getattr(idx, 'dtype', None) == np.bool_: + raise TypeError('ndarray indexing does not support boolean ndarray' + ' in a tuple of indices. Only single boolean ndarray' + ' as an index is supported.') return _NDARRAY_ADVANCED_INDEXING elif sys.version_info[0] > 2 and isinstance(idx, range): return _NDARRAY_ADVANCED_INDEXING diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 5ffb40dae7e0..9dce953004cf 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -35,7 +35,8 @@ 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', - 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer'] + 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', + 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal'] @set_module('mxnet.ndarray.numpy') @@ -3649,3 +3650,198 @@ def vdot(a, b): 30 """ return tensordot(a.flatten(), b.flatten(), 1) + + +@set_module('mxnet.ndarray.numpy') +def equal(x1, x2, out=None): + """ + Return (x1 == x2) element-wise. + Parameters + ---------- + x1, x2 : ndarrays or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : ndarray or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. 
+    See Also
+    --------
+    not_equal, greater_equal, less_equal, greater, less
+    Examples
+    --------
+    >>> np.equal(np.ones((2, 1)), np.zeros((1, 3)))
+    array([[False, False, False],
+           [False, False, False]])
+    >>> np.equal(1, np.ones(1))
+    array([ True])
+    """
+    return _ufunc_helper(x1, x2, _npi.equal, _np.equal, _npi.equal_scalar, None, out)
+
+
+@set_module('mxnet.ndarray.numpy')
+def not_equal(x1, x2, out=None):
+    """
+    Return (x1 != x2) element-wise.
+    Parameters
+    ----------
+    x1, x2 : ndarrays or scalars
+        Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to
+        a common shape (which becomes the shape of the output).
+    out : ndarray, None, or tuple of ndarray and None, optional
+        A location into which the result is stored. If provided, it must have
+        a shape that the inputs broadcast to. If not provided or `None`,
+        a freshly-allocated array is returned.
+    Returns
+    -------
+    out : ndarray or scalar
+        Output array of type bool, element-wise comparison of `x1` and `x2`.
+        This is a scalar if both `x1` and `x2` are scalars.
+    See Also
+    --------
+    equal, greater, greater_equal, less, less_equal
+    Examples
+    --------
+    >>> np.not_equal(np.ones((2, 1)), np.zeros((1, 3)))
+    array([[ True,  True,  True],
+           [ True,  True,  True]])
+    >>> np.not_equal(1, np.ones(1))
+    array([False])
+    """
+    return _ufunc_helper(x1, x2, _npi.not_equal, _np.not_equal, _npi.not_equal_scalar, None, out)
+
+
+@set_module('mxnet.ndarray.numpy')
+def greater(x1, x2, out=None):
+    """
+    Return the truth value of (x1 > x2) element-wise.
+    Parameters
+    ----------
+    x1, x2 : ndarrays or scalars
+        Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to
+        a common shape (which becomes the shape of the output).
+    out : ndarray, None, or tuple of ndarray and None, optional
+        A location into which the result is stored. If provided, it must have
+        a shape that the inputs broadcast to. If not provided or `None`,
+        a freshly-allocated array is returned.
+    Returns
+    -------
+    out : ndarray or scalar
+        Output array of type bool, element-wise comparison of `x1` and `x2`.
+        This is a scalar if both `x1` and `x2` are scalars.
+    See Also
+    --------
+    equal, not_equal, greater_equal, less, less_equal
+    Examples
+    --------
+    >>> np.greater(np.ones((2, 1)), np.zeros((1, 3)))
+    array([[ True,  True,  True],
+           [ True,  True,  True]])
+    >>> np.greater(1, np.ones(1))
+    array([False])
+    """
+    return _ufunc_helper(x1, x2, _npi.greater, _np.greater, _npi.greater_scalar,
+                         _npi.less_scalar, out)
+
+
+@set_module('mxnet.ndarray.numpy')
+def less(x1, x2, out=None):
+    """
+    Return the truth value of (x1 < x2) element-wise.
+    Parameters
+    ----------
+    x1, x2 : ndarrays or scalars
+        Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to
+        a common shape (which becomes the shape of the output).
+    out : ndarray, None, or tuple of ndarray and None, optional
+        A location into which the result is stored. If provided, it must have
+        a shape that the inputs broadcast to. If not provided or `None`,
+        a freshly-allocated array is returned.
+    Returns
+    -------
+    out : ndarray or scalar
+        Output array of type bool, element-wise comparison of `x1` and `x2`.
+        This is a scalar if both `x1` and `x2` are scalars.
+    See Also
+    --------
+    equal, not_equal, greater, greater_equal, less_equal
+    Examples
+    --------
+    >>> np.less(np.ones((2, 1)), np.zeros((1, 3)))
+    array([[False, False, False],
+           [False, False, False]])
+    >>> np.less(1, np.ones(1))
+    array([False])
+    """
+    return _ufunc_helper(x1, x2, _npi.less, _np.less, _npi.less_scalar, _npi.greater_scalar, out)
+
+
+@set_module('mxnet.ndarray.numpy')
+def greater_equal(x1, x2, out=None):
+    """
+    Return the truth value of (x1 >= x2) element-wise.
+    Parameters
+    ----------
+    x1, x2 : ndarrays or scalars
+        Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to
+        a common shape (which becomes the shape of the output).
+    out : ndarray, None, or tuple of ndarray and None, optional
+        A location into which the result is stored. If provided, it must have
+        a shape that the inputs broadcast to. If not provided or `None`,
+        a freshly-allocated array is returned.
+    Returns
+    -------
+    out : ndarray or scalar
+        Output array of type bool, element-wise comparison of `x1` and `x2`.
+        This is a scalar if both `x1` and `x2` are scalars.
+    See Also
+    --------
+    equal, not_equal, greater, less, less_equal
+    Examples
+    --------
+    >>> np.greater_equal(np.ones((2, 1)), np.zeros((1, 3)))
+    array([[ True,  True,  True],
+           [ True,  True,  True]])
+    >>> np.greater_equal(1, np.ones(1))
+    array([ True])
+    """
+    return _ufunc_helper(x1, x2, _npi.greater_equal, _np.greater_equal, _npi.greater_equal_scalar,
+                         _npi.less_equal_scalar, out)
+
+
+@set_module('mxnet.ndarray.numpy')
+def less_equal(x1, x2, out=None):
+    """
+    Return the truth value of (x1 <= x2) element-wise.
+    Parameters
+    ----------
+    x1, x2 : ndarrays or scalars
+        Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to
+        a common shape (which becomes the shape of the output).
+    out : ndarray, None, or tuple of ndarray and None, optional
+        A location into which the result is stored. If provided, it must have
+        a shape that the inputs broadcast to. If not provided or `None`,
+        a freshly-allocated array is returned.
+    Returns
+    -------
+    out : ndarray or scalar
+        Output array of type bool, element-wise comparison of `x1` and `x2`.
+        This is a scalar if both `x1` and `x2` are scalars.
+    See Also
+    --------
+    equal, not_equal, greater, greater_equal, less
+    Examples
+    --------
+    >>> np.less_equal(np.ones((2, 1)), np.zeros((1, 3)))
+    array([[False, False, False],
+           [False, False, False]])
+    >>> np.less_equal(1, np.ones(1))
+    array([ True])
+    """
+    return _ufunc_helper(x1, x2, _npi.less_equal, _np.less_equal, _npi.less_equal_scalar,
+                         _npi.greater_equal_scalar, out)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 01b2faa0c056..75b7cf65325b 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -29,7 +29,6 @@
 from builtins import slice as py_slice
 from array import array as native_array
-import sys
 import ctypes
 import warnings
 import numpy as _np
@@ -38,7 +37,7 @@ get_oshape_of_gather_nd_op
 from ..ndarray._internal import _set_np_ndarray_class
 from .
import _op as _mx_np_op -from ..base import check_call, _LIB, NDArrayHandle +from ..base import check_call, _LIB, NDArrayHandle, c_array from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types, integer_types from ..context import Context from ..util import _sanity_check_params, set_module @@ -54,7 +53,8 @@ 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', - 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer'] + 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', + 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -86,6 +86,35 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t): return hdl +def _reshape_view(a, *shape): + """Returns a **view** of this array with a new shape without altering any data. + + Parameters + ---------- + shape : tuple of int, or n ints + The new shape should not change the array size, namely + ``np.prod(new_shape)`` should be equal to ``np.prod(a.shape)``. + Some dimensions of the shape can take special value -1, which + infers the dimension of the output shape by using the remainder of the + input dimensions keeping the size of the new array same as that of the input array. + At most one dimension of shape can be -1. + + Returns + ------- + ndarray + An array with desired shape that shares data with this array. + """ + if len(shape) == 1 and isinstance(shape[0], (list, tuple)): + shape = shape[0] + handle = NDArrayHandle() + check_call(_LIB.MXNDArrayReshape64(a.handle, + len(shape), + c_array(ctypes.c_int64, shape), + False, + ctypes.byref(handle))) + return ndarray(handle=handle, writable=a.writable) + + # Have to use 0 as default value for stype since pylint does not allow # importing _STORAGE_TYPE_DEFAULT from ndarray.py. def _np_ndarray_cls(handle, writable=True, stype=0): @@ -97,18 +126,6 @@ def _np_ndarray_cls(handle, writable=True, stype=0): _set_np_ndarray_class(_np_ndarray_cls) - -def _get_index(idx): - if isinstance(idx, NDArray) and not isinstance(idx, ndarray): - raise TypeError('Cannot have mx.nd.NDArray as index') - if isinstance(idx, ndarray): - return idx.as_nd_ndarray() - elif sys.version_info[0] > 2 and isinstance(idx, range): - return array(_np.arange(idx.start, idx.stop, idx.step, dtype=_np.int32)).as_nd_ndarray() - else: - return idx - - _NUMPY_ARRAY_FUNCTION_DICT = {} _NUMPY_ARRAY_UFUNC_DICT = {} @@ -272,8 +289,35 @@ def __getitem__(self, key): Overriding the method in NDArray class in a numpy fashion. Calling numpy ndarray's _get_np_basic_indexing(key) and _get_np_advanced_indexing(key). 
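+
+        For example, with ``x = np.array([[1., 2.], [3., 4.]])``, a boolean index
+        of the same shape such as ``x[x > 2]`` flattens both the mask and the
+        data, applies ``_npi.boolean_mask``, and yields ``array([3., 4.])``.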
""" + # handling possible boolean indexing first ndim = self.ndim shape = self.shape + + if isinstance(key, list): + try: + new_key = _np.array(key) + if new_key.dtype == _np.bool_: + key = new_key + except Exception as err: + raise TypeError('{}'.format(str(err))) + if isinstance(key, _np.ndarray) and key.dtype == _np.bool_: + key = array(key, dtype='bool') + if isinstance(key, ndarray) and key.dtype == _np.bool_: # boolean indexing + key_shape = key.shape + key_ndim = len(key_shape) + if ndim < key_ndim: + raise IndexError('too many indices, whose ndim = {}, for array with ndim = {}' + .format(key_ndim, ndim)) + for i in range(key_ndim): + if key_shape[i] != shape[i]: + raise IndexError('boolean index did not match indexed array along dimension {};' + 'dimension is {} but corresponding boolean dimension is {}' + .format(i, shape[i], key_shape[i])) + remaining_dims = shape[key_ndim:] + data = _reshape_view(self, -1, *remaining_dims) + key = _reshape_view(key, -1) + return _reshape_view(_npi.boolean_mask(data, key), -1, *remaining_dims) + if ndim == 0: if key != (): raise IndexError('scalar tensor can only accept `()` as index') @@ -502,66 +546,30 @@ def __rpow__(self, other): def __eq__(self, other): """x.__eq__(y) <=> x == y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, ndarray): - return _npi.equal(self, other) - elif isinstance(other, numeric_types): - return _npi.equal_scalar(self, float(other)) - else: - raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) + return equal(self, other) def __hash__(self): raise NotImplementedError def __ne__(self, other): """x.__ne__(y) <=> x != y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, ndarray): - return _npi.not_equal(self, other) - elif isinstance(other, numeric_types): - return _npi.not_equal_scalar(self, float(other)) - else: - raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) + return not_equal(self, other) def __gt__(self, other): """x.__gt__(y) <=> x > y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, ndarray): - return _npi.greater(self, other) - elif isinstance(other, numeric_types): - return _npi.greater_scalar(self, float(other)) - else: - raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) + return greater(self, other) def __ge__(self, other): """x.__ge__(y) <=> x >= y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, ndarray): - return _npi.greater_equal(self, other) - elif isinstance(other, numeric_types): - return _npi.greater_equal_scalar(self, float(other)) - else: - raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) + return greater_equal(self, other) def __lt__(self, other): """x.__lt__(y) <=> x < y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, ndarray): - return _npi.less(self, other) - elif isinstance(other, numeric_types): - return _npi.less_scalar(self, float(other)) - else: - raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) + return less(self, other) def __le__(self, other): """x.__le__(y) <=> x <= y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, ndarray): - return _npi.less_equal(self, other) - elif isinstance(other, numeric_types): - return 
_npi.less_equal_scalar(self, float(other))
-        else:
-            raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
+        return less_equal(self, other)
 
     def __bool__(self):
         num_elements = self.size
@@ -694,7 +702,7 @@ def __repr__(self):
         if 'dtype=' in array_str:
             if dtype == _np.float32:
                 array_str = array_str[:array_str.rindex(',')] + ')'
-            elif dtype != _np.float32:
+            elif dtype not in (_np.float32, _np.bool_):
                 array_str = array_str[:-1] + ', dtype={})'.format(dtype.__name__)
 
         context = self.context
@@ -5165,3 +5173,195 @@ def vdot(a, b):
     30
     """
     return tensordot(a.flatten(), b.flatten(), 1)
+
+
+@set_module('mxnet.numpy')
+def equal(x1, x2, out=None):
+    """
+    Return (x1 == x2) element-wise.
+    Parameters
+    ----------
+    x1, x2 : ndarrays or scalars
+        Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to
+        a common shape (which becomes the shape of the output).
+    out : ndarray, None, or tuple of ndarray and None, optional
+        A location into which the result is stored. If provided, it must have
+        a shape that the inputs broadcast to. If not provided or `None`,
+        a freshly-allocated array is returned.
+    Returns
+    -------
+    out : ndarray or scalar
+        Output array of type bool, element-wise comparison of `x1` and `x2`.
+        This is a scalar if both `x1` and `x2` are scalars.
+    See Also
+    --------
+    not_equal, greater_equal, less_equal, greater, less
+    Examples
+    --------
+    >>> np.equal(np.ones((2, 1)), np.zeros((1, 3)))
+    array([[False, False, False],
+           [False, False, False]])
+    >>> np.equal(1, np.ones(1))
+    array([ True])
+    """
+    return _mx_nd_np.equal(x1, x2, out)
+
+
+@set_module('mxnet.numpy')
+def not_equal(x1, x2, out=None):
+    """
+    Return (x1 != x2) element-wise.
+    Parameters
+    ----------
+    x1, x2 : ndarrays or scalars
+        Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to
+        a common shape (which becomes the shape of the output).
+    out : ndarray, None, or tuple of ndarray and None, optional
+        A location into which the result is stored. If provided, it must have
+        a shape that the inputs broadcast to. If not provided or `None`,
+        a freshly-allocated array is returned.
+    Returns
+    -------
+    out : ndarray or scalar
+        Output array of type bool, element-wise comparison of `x1` and `x2`.
+        This is a scalar if both `x1` and `x2` are scalars.
+    See Also
+    --------
+    equal, greater, greater_equal, less, less_equal
+    Examples
+    --------
+    >>> np.not_equal(np.ones((2, 1)), np.zeros((1, 3)))
+    array([[ True,  True,  True],
+           [ True,  True,  True]])
+    >>> np.not_equal(1, np.ones(1))
+    array([False])
+    """
+    return _mx_nd_np.not_equal(x1, x2, out)
+
+
+@set_module('mxnet.numpy')
+def greater(x1, x2, out=None):
+    """
+    Return the truth value of (x1 > x2) element-wise.
+    Parameters
+    ----------
+    x1, x2 : ndarrays or scalars
+        Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to
+        a common shape (which becomes the shape of the output).
+    out : ndarray, None, or tuple of ndarray and None, optional
+        A location into which the result is stored. If provided, it must have
+        a shape that the inputs broadcast to. If not provided or `None`,
+        a freshly-allocated array is returned.
+    Returns
+    -------
+    out : ndarray or scalar
+        Output array of type bool, element-wise comparison of `x1` and `x2`.
+        This is a scalar if both `x1` and `x2` are scalars.
+ + +@set_module('mxnet.numpy') +def greater(x1, x2, out=None): + """ + Return the truth value of (x1 > x2) element-wise. + Parameters + ---------- + x1, x2 : ndarrays or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : ndarray or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, not_equal, greater_equal, less, less_equal + Examples + -------- + >>> np.greater(np.ones((2, 1)), np.zeros((1, 3))) + array([[ True, True, True], + [ True, True, True]]) + >>> np.greater(1, np.ones(1)) + array([False]) + """ + return _mx_nd_np.greater(x1, x2, out) + + +@set_module('mxnet.numpy') +def less(x1, x2, out=None): + """ + Return the truth value of (x1 < x2) element-wise. + Parameters + ---------- + x1, x2 : ndarrays or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : ndarray or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, not_equal, greater, greater_equal, less_equal + Examples + -------- + >>> np.less(np.ones((2, 1)), np.zeros((1, 3))) + array([[False, False, False], + [False, False, False]]) + >>> np.less(1, np.ones(1)) + array([False]) + """ + return _mx_nd_np.less(x1, x2, out) + + +@set_module('mxnet.numpy') +def greater_equal(x1, x2, out=None): + """ + Return the truth value of (x1 >= x2) element-wise. + Parameters + ---------- + x1, x2 : ndarrays or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : ndarray or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, not_equal, greater, less, less_equal + Examples + -------- + >>> np.greater_equal(np.ones((2, 1)), np.zeros((1, 3))) + array([[ True, True, True], + [ True, True, True]]) + >>> np.greater_equal(1, np.ones(1)) + array([ True]) + """ + return _mx_nd_np.greater_equal(x1, x2, out)
+ + +@set_module('mxnet.numpy') +def less_equal(x1, x2, out=None): + """ + Return the truth value of (x1 <= x2) element-wise. + Parameters + ---------- + x1, x2 : ndarrays or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : ndarray or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, not_equal, greater, greater_equal, less + Examples + -------- + >>> np.less_equal(np.ones((2, 1)), np.zeros((1, 3))) + array([[False, False, False], + [False, False, False]]) + >>> np.less_equal(1, np.ones(1)) + array([ True]) + """ + return _mx_nd_np.less_equal(x1, x2, out)
diff --git a/python/mxnet/numpy/utils.py b/python/mxnet/numpy/utils.py index 920897efc80b..b2d0dd96d324 100644 --- a/python/mxnet/numpy/utils.py +++ b/python/mxnet/numpy/utils.py @@ -22,7 +22,8 @@ import numpy as onp -__all__ = ['float16', 'float32', 'float64', 'uint8', 'int32', 'int8', 'int64', 'pi'] +__all__ = ['float16', 'float32', 'float64', 'uint8', 'int32', 'int8', 'int64', + 'bool', 'bool_', 'pi', 'inf', 'nan'] float16 = onp.float16 float32 = onp.float32 @@ -31,5 +32,9 @@ int32 = onp.int32 int8 = onp.int8 int64 = onp.int64 +bool_ = onp.bool_ +bool = onp.bool pi = onp.pi +inf = onp.inf +nan = onp.nan
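
The `utils.py` additions above are thin re-exports of NumPy's own scalar types and constants, so the two frameworks agree on dtype identity. A small sanity sketch (assuming the aliases are surfaced on the `mxnet.numpy` namespace, as `__all__` suggests):

```python
import numpy as onp
from mxnet import np

assert np.bool_ is onp.bool_            # same underlying scalar type
assert np.inf == onp.inf and np.pi == onp.pi
assert np.nan != np.nan                 # NaN never compares equal, even to itself
print(np.array([True, False], dtype=np.bool_).dtype)   # bool
```
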
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 7ba92e9cab6e..3eaf80a1b6fb 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -37,7 +37,8 @@ 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', - 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer'] + 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', + 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal'] def _num_outputs(sym): @@ -138,63 +139,27 @@ def __deepcopy__(self, _): def __eq__(self, other): """x.__eq__(y) <=> x == y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, _Symbol): - return _npi.equal(self, other) - elif isinstance(other, numeric_types): - return _npi.equal_scalar(self, float(other)) - else: - raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) + return equal(self, other) def __ne__(self, other): """x.__ne__(y) <=> x != y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, _Symbol): - return _npi.not_equal(self, other) - elif isinstance(other, numeric_types): - return _npi.not_equal_scalar(self, float(other)) - else: - raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) + return not_equal(self, other) def __gt__(self, other): """x.__gt__(y) <=> x > y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, _Symbol): - return _npi.greater(self, other) - elif isinstance(other, numeric_types): - return _npi.greater_scalar(self, float(other)) - else: - raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) + return greater(self, other) def __ge__(self, other): """x.__ge__(y) <=> x >= y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, _Symbol): - return _npi.greater_equal(self, other) - elif isinstance(other, numeric_types): - return _npi.greater_equal_scalar(self, float(other)) - else: - raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) + return greater_equal(self, other) def __lt__(self, other): """x.__lt__(y) <=> x < y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, _Symbol): - return _npi.less(self, other) - elif isinstance(other, numeric_types): - return _npi.less_scalar(self, float(other)) - else: - raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) + return less(self, other) def __le__(self, other): """x.__le__(y) <=> x <= y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, _Symbol): - return _npi.less_equal(self, other) - elif isinstance(other, numeric_types): - return _npi.less_equal_scalar(self, float(other)) - else: - raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) + return less_equal(self, other) def __len__(self): raise NotImplementedError @@ -3732,4 +3697,199 @@ def vdot(a, b): return tensordot(a.flatten(), b.flatten(), 1) + +@set_module('mxnet.symbol.numpy') +def equal(x1, x2, out=None): + """ + Return (x1 == x2) element-wise. + Parameters + ---------- + x1, x2 : _Symbol or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : Dummy parameter, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : _Symbol or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + not_equal, greater_equal, less_equal, greater, less + Examples + -------- + >>> np.equal(np.ones((2, 1)), np.zeros((1, 3))) + array([[False, False, False], + [False, False, False]]) + >>> np.equal(1, np.ones(1)) + array([ True]) + """ + return _ufunc_helper(x1, x2, _npi.equal, _np.equal, _npi.equal_scalar, None, out) + + +@set_module('mxnet.symbol.numpy') +def not_equal(x1, x2, out=None): + """ + Return (x1 != x2) element-wise. + Parameters + ---------- + x1, x2 : _Symbol or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : Dummy parameter, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : _Symbol or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, greater, greater_equal, less, less_equal + Examples + -------- + >>> np.not_equal(np.ones((2, 1)), np.zeros((1, 3))) + array([[ True, True, True], + [ True, True, True]]) + >>> np.not_equal(1, np.ones(1)) + array([False]) + """ + return _ufunc_helper(x1, x2, _npi.not_equal, _np.not_equal, _npi.not_equal_scalar, None, out)
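
Because `_Symbol` now routes its comparison dunders through these functions, the same Python expression also traces symbolically. A hypothetical minimal block (names invented for illustration; TVM-enabled build assumed):

```python
from mxnet import np, npx, gluon
npx.set_np()

class NotEqualBlock(gluon.HybridBlock):
    def hybrid_forward(self, F, x, y):
        # On the symbolic pass x and y are _Symbol, so `!=` lowers to _npi.not_equal.
        return x != y

block = NotEqualBlock()
block.hybridize()
print(block(np.ones((2, 2)), np.zeros((2, 2))))   # all True, dtype bool
```
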
+ + +@set_module('mxnet.symbol.numpy') +def greater(x1, x2, out=None): + """ + Return the truth value of (x1 > x2) element-wise. + Parameters + ---------- + x1, x2 : _Symbol or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : Dummy parameter, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : _Symbol or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, not_equal, greater_equal, less, less_equal + Examples + -------- + >>> np.greater(np.ones((2, 1)), np.zeros((1, 3))) + array([[ True, True, True], + [ True, True, True]]) + >>> np.greater(1, np.ones(1)) + array([False]) + """ + return _ufunc_helper(x1, x2, _npi.greater, _np.greater, _npi.greater_scalar, + _npi.less_scalar, out) + + +@set_module('mxnet.symbol.numpy') +def less(x1, x2, out=None): + """ + Return the truth value of (x1 < x2) element-wise. + Parameters + ---------- + x1, x2 : _Symbol or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : Dummy parameter, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : _Symbol or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, not_equal, greater, greater_equal, less_equal + Examples + -------- + >>> np.less(np.ones((2, 1)), np.zeros((1, 3))) + array([[False, False, False], + [False, False, False]]) + >>> np.less(1, np.ones(1)) + array([False]) + """ + return _ufunc_helper(x1, x2, _npi.less, _np.less, _npi.less_scalar, _npi.greater_scalar, out) + + +@set_module('mxnet.symbol.numpy') +def greater_equal(x1, x2, out=None): + """ + Return the truth value of (x1 >= x2) element-wise. + Parameters + ---------- + x1, x2 : _Symbol or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : Dummy parameter, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : _Symbol or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, not_equal, greater, less, less_equal + Examples + -------- + >>> np.greater_equal(np.ones((2, 1)), np.zeros((1, 3))) + array([[ True, True, True], + [ True, True, True]]) + >>> np.greater_equal(1, np.ones(1)) + array([ True]) + """ + return _ufunc_helper(x1, x2, _npi.greater_equal, _np.greater_equal, _npi.greater_equal_scalar, + _npi.less_equal_scalar, out)
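
Note how `greater` passes `_npi.less_scalar` as its reversed-operand kernel: when the scalar sits on the left, `s > x` is rewritten as `x < s`. A runnable illustration of this assumed dispatch order, with plain NumPy standing in for the `_npi` kernels:

```python
import numpy as onp

def dispatch_greater(x1, x2):
    is_t = lambda v: isinstance(v, onp.ndarray)
    if is_t(x1) and is_t(x2):
        return onp.greater(x1, x2)   # stands in for _npi.greater: tensor > tensor
    if is_t(x1):
        return onp.greater(x1, x2)   # stands in for _npi.greater_scalar: tensor > scalar
    if is_t(x2):
        return onp.less(x2, x1)      # stands in for _npi.less_scalar: scalar > tensor, swapped
    return onp.greater(x1, x2)       # two scalars: defer to NumPy itself

print(dispatch_greater(2, onp.array([1., 3.])))   # [ True False]
```
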
+ + +@set_module('mxnet.symbol.numpy') +def less_equal(x1, x2, out=None): + """ + Return the truth value of (x1 <= x2) element-wise. + Parameters + ---------- + x1, x2 : _Symbol or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : Dummy parameter, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : _Symbol or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, not_equal, greater, greater_equal, less + Examples + -------- + >>> np.less_equal(np.ones((2, 1)), np.zeros((1, 3))) + array([[False, False, False], + [False, False, False]]) + >>> np.less_equal(1, np.ones(1)) + array([ True]) + """ + return _ufunc_helper(x1, x2, _npi.less_equal, _np.less_equal, _npi.less_equal_scalar, + _npi.greater_equal_scalar, out) + + _set_np_symbol_class(_Symbol)
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index abdf57039c3a..c5b3b0d8e4cd 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -50,6 +50,7 @@ from .symbol import Symbol from .symbol.numpy import _Symbol as np_symbol from .util import use_np # pylint: disable=unused-import +from .runtime import Features def default_context(): @@ -2225,3 +2226,19 @@ def collapse_sum_like(a, shape): def is_cd_run(): """Checks if the test is running as part of a Continuous Delivery run""" return os.environ.get("CD_JOB", 0) == "1" + + +_features = Features() + + +def has_tvm_ops(): + """Returns True if MXNet is compiled with TVM-generated operators. If the current context + is GPU, it returns True only for CUDA compute capability >= 53, where FP16 is supported.""" + built_with_tvm_op = _features.is_enabled("TVM_OP") + if current_context().device_type == 'gpu': + try: + import tvm + except ImportError: + return False + return built_with_tvm_op and (int("".join(tvm.nd.gpu(0).compute_version.split('.'))) >= 53) + return built_with_tvm_op
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index cc21dd242a2d..d1c3132b79a7 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -281,7 +281,7 @@ NDArray NDArray::Slice(index_t begin, index_t end) const { CHECK_EQ(storage_type(), kDefaultStorage); NDArray ret = this->Detach(); size_t length = shape_.ProdShape(1, shape_.ndim()); - MSHADOW_TYPE_SWITCH(ret.dtype(), DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(ret.dtype(), DType, { ret.byte_offset_ += begin * length * sizeof(DType); }); ret.reuse_ = false; diff --git a/src/ndarray/ndarray_function.cc b/src/ndarray/ndarray_function.cc index 1a699b12d76d..653dec84563a 100644 --- a/src/ndarray/ndarray_function.cc +++ b/src/ndarray/ndarray_function.cc @@ -36,6 +36,15 @@ template<> void Copy(const TBlob &from, TBlob *to, Context from_ctx, Context to_ctx, RunContext ctx) { + if (from.type_flag_ == mshadow::kBool || to->type_flag_ == mshadow::kBool) { + CHECK_EQ(from.type_flag_, to->type_flag_) << "Only supports copying data between" + " two boolean tensors."; + const index_t size = from.Size(); + CHECK_EQ(size, to->Size()) << "copying size mismatch, from: " << size * sizeof(bool) + << " bytes, to: " << to->Size() * sizeof(bool) << " bytes."; + common::ParallelCopy(to->dptr(), from.dptr(), size); + return; + } MSHADOW_TYPE_SWITCH(to->type_flag_, DType, { if (to->type_flag_ == from.type_flag_) { const index_t size = static_cast(from.Size()); diff --git a/src/ndarray/ndarray_function.cu b/src/ndarray/ndarray_function.cu index 0ab40a794198..6439c417bfe3 100644 --- a/src/ndarray/ndarray_function.cu +++ b/src/ndarray/ndarray_function.cu @@ -44,7 +44,7 @@ void Copy(const TBlob &from, TBlob *to, RunContext ctx) { CHECK_EQ(to->type_flag_, from.type_flag_) << "Source and target must have the same data type when copying across devices."; - MSHADOW_TYPE_SWITCH(to->type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(to->type_flag_, DType, { mshadow::Copy(to->FlatTo1D(), from.FlatTo1D(), ctx.get_stream()); @@ -57,7 +57,7 @@ void Copy(const TBlob &from, TBlob *to, RunContext 
ctx) { CHECK_EQ(to->type_flag_, from.type_flag_) << "Source and target must have the same data type when copying across devices."; - MSHADOW_TYPE_SWITCH(to->type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(to->type_flag_, DType, { mshadow::Copy(to->FlatTo1D(), from.FlatTo1D(), ctx.get_stream()); @@ -70,12 +70,16 @@ void Copy(const TBlob &from, TBlob *to, RunContext ctx) { if (from_ctx.dev_id == to_ctx.dev_id) { mshadow::Stream* s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(to->type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(to->type_flag_, DType, { if (to->type_flag_ == from.type_flag_) { mshadow::Copy(to->FlatTo1D(s), from.FlatTo1D(s), s); } else { + CHECK_NE(from.type_flag_, mshadow::kBool) + << "Copying boolean ndarray across devices is not supported"; + CHECK_NE(to->type_flag_, mshadow::kBool) + << "Copying boolean ndarray across devices is not supported"; MSHADOW_TYPE_SWITCH(from.type_flag_, SrcDType, { to->FlatTo1D(s) = mshadow::expr::tcast(from.FlatTo1D(s)); diff --git a/src/operator/contrib/boolean_mask.cc b/src/operator/contrib/boolean_mask.cc index 0afc57f6ab54..a54cc917776d 100644 --- a/src/operator/contrib/boolean_mask.cc +++ b/src/operator/contrib/boolean_mask.cc @@ -35,7 +35,7 @@ bool BooleanMaskType(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1); TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); - return out_attrs->at(0) != -1; + return in_attrs->at(0) != -1 && in_attrs->at(1) != -1 && out_attrs->at(0) != -1; } bool BooleanMaskStorageType(const nnvm::NodeAttrs& attrs, @@ -115,7 +115,7 @@ inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs, std::vector prefix_sum(idx_size, 0); size_t valid_num = 0; // Calculate prefix sum - MSHADOW_TYPE_SWITCH(idx.dtype(), DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(idx.dtype(), DType, { DType* idx_dptr = idx.data().dptr(); for (size_t i = 0; i < idx_size; i++) { prefix_sum[i] = (i == 0) ? 
0 : prefix_sum[i - 1]; @@ -154,7 +154,7 @@ inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs, const NDArray& idx = inputs[2]; const NDArray& igrad_data = outputs[0]; MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, { - MSHADOW_TYPE_SWITCH(idx.dtype(), IType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(idx.dtype(), IType, { size_t input_size = igrad_data.shape().Size(); size_t idx_size = idx.shape()[0]; size_t col_size = input_size / idx_size; @@ -179,6 +179,7 @@ inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs, } NNVM_REGISTER_OP(_contrib_boolean_mask) +.add_alias("_npi_boolean_mask") .describe(R"code( Given an n-d NDArray data, and a 1-d NDArray index, the operator produces an un-predeterminable shaped n-d NDArray out, diff --git a/src/operator/contrib/boolean_mask.cu b/src/operator/contrib/boolean_mask.cu index c4a06d25d70a..71d91c63f64e 100644 --- a/src/operator/contrib/boolean_mask.cu +++ b/src/operator/contrib/boolean_mask.cu @@ -66,7 +66,7 @@ inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs, ctx.requested[0].get_space_typed(Shape1(temp_storage_bytes), s); prefix_sum = reinterpret_cast(workspace.dptr_); d_temp_storage = workspace.dptr_ + buffer_size; - MSHADOW_TYPE_SWITCH(idx.dtype(), IType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(idx.dtype(), IType, { mxnet_op::Kernel::Launch( s, idx.shape()[0], prefix_sum, idx.data().dptr()); }); @@ -129,7 +129,7 @@ inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs, ctx.requested[0].get_space_typed(Shape1(temp_storage_bytes), s); prefix_sum = reinterpret_cast(workspace.dptr_); d_temp_storage = workspace.dptr_ + buffer_size; - MSHADOW_TYPE_SWITCH(idx.dtype(), IType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(idx.dtype(), IType, { mxnet_op::Kernel::Launch( s, idx.shape()[0], prefix_sum, idx.data().dptr()); }); diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index d90aa268195a..b46ce8a598d9 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -491,6 +491,17 @@ struct AccType { .add_enum("int64", mshadow::kInt64) +#define MXNET_ADD_ALL_TYPES_WITH_BOOL \ + .add_enum("float32", mshadow::kFloat32) \ + .add_enum("float64", mshadow::kFloat64) \ + .add_enum("float16", mshadow::kFloat16) \ + .add_enum("uint8", mshadow::kUint8) \ + .add_enum("int8", mshadow::kInt8) \ + .add_enum("int32", mshadow::kInt32) \ + .add_enum("int64", mshadow::kInt64) \ + .add_enum("bool", mshadow::kBool) + + /* \brief Compute flattened index given coordinates and shape. 
*/ template MSHADOW_XINLINE index_t ravel(const Shape& coord, const Shape& shape) { @@ -597,6 +608,11 @@ template MSHADOW_CINLINE void copy(mshadow::Stream *s, const TBlob& to, const TBlob& from) { CHECK_EQ(from.Size(), to.Size()); CHECK_EQ(from.dev_mask(), to.dev_mask()); + if (from.type_flag_ == mshadow::kBool || to.type_flag_ == mshadow::kBool) { + CHECK_EQ(from.type_flag_, to.type_flag_) << "Only supports copying between boolean ndarrays."; + mshadow::Copy(to.FlatTo1D(s), from.FlatTo1D(s), s); + return; + } MSHADOW_TYPE_SWITCH(to.type_flag_, DType, { if (to.type_flag_ == from.type_flag_) { mshadow::Copy(to.FlatTo1D(s), from.FlatTo1D(s), s); diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h index b1348de72de5..816091a48b48 100644 --- a/src/operator/numpy/np_broadcast_reduce_op.h +++ b/src/operator/numpy/np_broadcast_reduce_op.h @@ -27,6 +27,7 @@ #include #include +#include #include "../nn/moments-inl.h" #include "../tensor/broadcast_reduce_op.h" @@ -86,7 +87,6 @@ struct NumpyReduceAxesNoDTypeParam : public dmlc::Parameter>& axis, bool keepdims) { - // TODO(junwu): improve the logic // If input is a scalar, output should be a scalar too if (ishape.ndim() == 0) { if (axis.has_value()) { @@ -210,6 +210,10 @@ inline bool NeedSafeAcc(int itype, int otype) { return safe_acc_hint && rule; } +void TVMOpReduce(const OpContext& ctx, const TBlob& input, + const dmlc::optional>& axis, + const TBlob& output, const OpReqType req, const std::string& reducer_name); + template void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs, @@ -217,6 +221,7 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { + if (req[0] == kNullOp) return; const NumpyReduceAxesParam& param = nnvm::get(attrs.parsed); if (param.initial.has_value()) { LOG(FATAL) << "initial is not supported yet"; @@ -230,6 +235,18 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs, }); return; } + CHECK_NE(req[0], kWriteInplace) << "Reduce does not support write in-place"; + // If boolean ndarray, use the kernel generated by TVM + if (inputs[0].type_flag_ == mshadow::kBool) { + std::string reducer_name; + if (std::is_same::value) { + reducer_name = "sum"; + } else { + LOG(FATAL) << "Only reduce op: `sum` is supported for boolean ndarrays"; + } + TVMOpReduce(ctx, inputs[0], param.axis, outputs[0], req[0], reducer_name); + return; + } if (param.axis.has_value() && param.axis.value().ndim() == 0) { UnaryOp::IdentityCompute(attrs, ctx, inputs, req, outputs); } @@ -279,6 +296,8 @@ inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { using namespace mshadow; using namespace mshadow::expr; + CHECK_NE(outputs[0].type_flag_, kBool) << "reduce operators do not support gradient calculation " + "for input tensors of boolean type."; const NumpyReduceAxesParam& param = nnvm::get(attrs.parsed); TShape small; if (param.keepdims) { diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc index 774bc11f5de8..fdda792a9ed8 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cc +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc @@ -23,6 +23,10 @@ * \brief CPU Implementation of broadcast and reduce functions based on value. 
*/ +#if MXNET_USE_TVM_OP +#include "../tvmop/op_module.h" +#endif // MXNET_USE_TVM_OP + #include "np_broadcast_reduce_op.h" namespace mxnet { @@ -40,7 +44,17 @@ inline bool NumpySumType(const nnvm::NodeAttrs& attrs, const NumpyReduceAxesParam &param = nnvm::get(attrs.parsed); if (param.dtype.has_value()) { + if (in_attrs->at(0) == mshadow::kBool) { + CHECK(param.dtype.value() == mshadow::kInt32 + || param.dtype.value() == mshadow::kInt64 + || param.dtype.value() == mshadow::kFloat32 + || param.dtype.value() == mshadow::kFloat64) << "Only support the following output " + "dtypes when input dtype is bool: " + "int32, int64, float32, float64."; + } TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value()); + } else if (in_attrs->at(0) == mshadow::kBool) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64); } else { TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); @@ -49,6 +63,63 @@ inline bool NumpySumType(const nnvm::NodeAttrs& attrs, return out_attrs->at(0) != -1 && in_attrs->at(0) != -1; } +#if MXNET_USE_TVM_OP +static constexpr int max_reduce_ndim = 5; +TBlob PrependAxes(const TBlob& src, const int dst_ndim); +#endif // MXNET_USE_TVM_OP + +void TVMOpReduce(const OpContext& ctx, + const TBlob& input, + const dmlc::optional>& axis, + const TBlob& output, + const OpReqType req, + const std::string& reducer_name) { +#if MXNET_USE_TVM_OP + CHECK_GE(input.ndim(), output.ndim()); + CHECK_LE(input.ndim(), max_reduce_ndim) << "TVMOpReduce only supports ndim <= " + << max_reduce_ndim; + + const TBlob expanded_output = (input.ndim() == output.ndim() ? + output : output.reshape(NumpyReduceAxesShapeImpl(input.shape_, axis, true))); + CHECK_EQ(input.ndim(), expanded_output.ndim()); + int reduce1st_dim = 0; + if (input.ndim() > 0 && input.size(0) != expanded_output.size(0)) { + reduce1st_dim = 1; + } + // collapse consecutive dimensions where the reduction is or is not performed + std::vector ishape_vec; + for (int i = 0; i < input.ndim(); ++i) { + if (i == 0 || ((input.size(i) != expanded_output.size(i)) + != (input.size(i-1) != expanded_output.size(i-1)))) { + ishape_vec.push_back(input.size(i)); + } else { + ishape_vec.back() *= input.size(i); + } + } + // append axes after collapsed ishape to reach the max ndim allowed + for (int i = ishape_vec.size(); i < max_reduce_ndim; ++i) { + ishape_vec.push_back(1); + } + std::vector oshape_vec; + for (size_t i = reduce1st_dim; i < ishape_vec.size(); i += 2) { + oshape_vec.push_back(ishape_vec[i]); + } + TShape ishape(ishape_vec.begin(), ishape_vec.end()), oshape(oshape_vec.begin(), oshape_vec.end()); + TBlob input_tvm = input.reshape(ishape); + TBlob output_tvm = output.reshape(oshape); + const std::string ctx_name = + (ctx.run_ctx.ctx.dev_type == mxnet::Context::DeviceType::kCPU) ? "cpu" : "gpu"; + std::ostringstream func_name; + func_name << reducer_name << "_" << ctx_name + << "reduce1st_dim_" << reduce1st_dim + << "req_" << (req == kWriteTo ? "kWriteTo" : "kAddTo"); + tvm::runtime::TVMOpModule::Get()->Call(func_name.str(), ctx, {input_tvm, output_tvm, output_tvm}); +#else + LOG(FATAL) << "Please add USE_TVM_OP=1 as a compile flag to enable TVM-generated kernels."; +#endif // MXNET_USE_TVM_OP +}
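
The collapsing loop above merges runs of consecutive axes that are all reduced or all kept, so the pre-compiled TVM kernel only ever sees an alternating kept/reduced pattern padded out to `max_reduce_ndim`. A hedged Python model of that bookkeeping:

```python
def collapse_axes(ishape, reduced, max_ndim=5):
    """Mimics TVMOpReduce's shape collapsing; `reduced` flags each input axis."""
    dims, flags = [], []
    for d, r in zip(ishape, reduced):
        if flags and flags[-1] == r:
            dims[-1] *= d            # merge with the previous axis of the same status
        else:
            dims.append(d)
            flags.append(r)
    dims += [1] * (max_ndim - len(dims))   # pad with unit axes up to max_ndim
    return dims, flags

# Summing axes 1 and 2 of a (2, 3, 4, 5) tensor: the kernel sees (2, 12, 5, 1, 1),
# and reduce1st_dim is 0 because axis 0 is kept.
print(collapse_axes((2, 3, 4, 5), (False, True, True, False)))
```
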
"kWriteTo" : "kAddTo"); + tvm::runtime::TVMOpModule::Get()->Call(func_name.str(), ctx, {input_tvm, output_tvm, output_tvm}); +#else + LOG(FATAL) << "Please add USE_TVM_OP=1 as a compile flag to enable TVM-generated kernels."; +#endif // MXNET_USE_TVM_OP +} + NNVM_REGISTER_OP(_np_sum) .describe(R"code()code" ADD_FILELINE) .set_num_inputs(1) diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index fd109f040b50..aa81a58c2890 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -23,6 +23,12 @@ * \brief CPU Implementation of basic functions for elementwise numpy binary broadcast operator. */ +#if MXNET_USE_TVM_OP +#include +#include +#include "../tvmop/op_module.h" +#endif // MXNET_USE_TVM_OP + #include "../tensor/elemwise_binary_broadcast_op.h" #include "../tensor/elemwise_binary_scalar_op.h" @@ -113,6 +119,22 @@ NNVM_REGISTER_OP(_npi_lcm) .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function"); +NNVM_REGISTER_OP(_npi_lcm_scalar) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser([](NodeAttrs* attrs) { + attrs->parsed = std::stod(attrs->dict["scalar"]); + }) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseIntType<1, 1>) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; + }) +.add_argument("data", "NDArray-or-Symbol", "source input") +.add_argument("scalar", "int", "scalar input") +.set_attr("FCompute", BinaryScalarOp::Compute); + MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_add_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseNone{"_copy"}); @@ -280,21 +302,222 @@ NNVM_REGISTER_OP(_backward_npi_hypot) .set_attr("FCompute", BinaryBroadcastBackwardUseIn); -NNVM_REGISTER_OP(_npi_lcm_scalar) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = std::stod(attrs->dict["scalar"]); - }) -.set_attr("FInferShape", ElemwiseShape<1, 1>) -.set_attr("FInferType", ElemwiseIntType<1, 1>) -.set_attr("FInplaceOption", - [](const NodeAttrs& attrs){ - return std::vector >{{0, 0}}; - }) -.add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "int", "scalar input") -.set_attr("FCompute", BinaryScalarOp::Compute); +static constexpr char func_equal_cpu[] = "equal_cpu"; +static constexpr char func_equal_gpu[] = "equal_gpu"; +static constexpr char func_not_equal_cpu[] = "not_equal_cpu"; +static constexpr char func_not_equal_gpu[] = "not_equal_gpu"; +static constexpr char func_greater_cpu[] = "greater_cpu"; +static constexpr char func_greater_gpu[] = "greater_gpu"; +static constexpr char func_less_cpu[] = "less_cpu"; +static constexpr char func_less_gpu[] = "less_gpu"; +static constexpr char func_greater_equal_cpu[] = "greater_equal_cpu"; +static constexpr char func_greater_equal_gpu[] = "greater_equal_gpu"; +static constexpr char func_less_equal_cpu[] = "less_equal_cpu"; +static constexpr char func_less_equal_gpu[] = "less_equal_gpu"; + +bool NumpyBinaryLogicOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + if (in_attrs->at(0) == -1 && in_attrs->at(1) == -1) return false; + TYPE_ASSIGN_CHECK(*in_attrs, 0, in_attrs->at(1)); + TYPE_ASSIGN_CHECK(*in_attrs, 1, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*out_attrs, 
0, mshadow::kBool); + return true; +} + +TBlob PrependAxes(const TBlob& src, const int dst_ndim) { + CHECK_LE(src.shape_.ndim(), dst_ndim); + const int src_ndim = src.shape_.ndim(); + if (src_ndim == dst_ndim) return src; + mxnet::TShape dst_shape(dst_ndim, 1); + for (int i = dst_ndim - src_ndim; i < dst_ndim; ++i) { + dst_shape[i] = src.shape_[i - dst_ndim + src_ndim]; + } + return src.reshape(dst_shape); +} + +struct TVMBinaryBroadcastCompute { + const char* func; + void operator()(const nnvm::NodeAttrs& attrs, + const mxnet::OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { +#if MXNET_USE_TVM_OP + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + if (outputs[0].shape_.Size() == 0U) return; // skip zero-size tensor + + // prepare tblobs and TVMArgs + std::vector tblobs = {inputs[0], inputs[1], outputs[0]}; + std::vector type_codes; + std::vector values; + + const int ondim = outputs[0].shape_.ndim(); + const size_t num_args = inputs.size() + outputs.size(); + type_codes.resize(num_args); + values.resize(num_args); + for (size_t i = 0; i < num_args; ++i) { + tblobs[i] = PrependAxes(tblobs[i], ondim); + type_codes[i] = kArrayHandle; + values[i].v_handle = const_cast(&(tblobs[i].dltensor())); + } + tvm::runtime::TVMArgs tvm_args(&values[0], &type_codes[0], tblobs.size()); + tvm::runtime::TVMOpModule::Get()->CallEx(func, ctx, tblobs, tvm_args); +#else + LOG(FATAL) << "Please add USE_TVM_OP=1 as a compile flag for compiling MXNet source code " + "to enable TVM-generated kernels for operator " << func; +#endif // MXNET_USE_TVM_OP + } +}; + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(name) \ + NNVM_REGISTER_OP(_npi_##name) \ + .set_num_inputs(2) \ + .set_num_outputs(1) \ + .set_attr("FListInputNames", \ + [](const NodeAttrs& attrs) { \ + return std::vector{"lhs", "rhs"}; \ + }) \ + .set_attr("FInferShape", BinaryBroadcastShape) \ + .set_attr("FInferType", NumpyBinaryLogicOpType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs) { \ + return std::vector >{{0, 0}, {1, 0}}; \ + }) \ + .set_attr("FCompute", TVMBinaryBroadcastCompute{func_##name##_cpu}) \ + .set_attr("FGradient", MakeZeroGradNodes) \ + .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \ + .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") + +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(less_equal); + +#if MXNET_USE_CUDA +#define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(name) \ + NNVM_REGISTER_OP(_npi_##name) \ + .set_attr("FCompute", TVMBinaryBroadcastCompute{func_##name##_gpu}) + +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(less_equal); +#endif // MXNET_USE_CUDA + +bool NumpyBinaryScalarLogicOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + if (in_attrs->at(0) == -1) return false; + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool); + return true; +} + +struct 
TVMBinaryBroadcastScalarCompute { + const char* func; + void operator()(const nnvm::NodeAttrs& attrs, + const mxnet::OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { +#if MXNET_USE_TVM_OP + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + if (outputs[0].shape_.Size() == 0U) return; // skip zero-size tensor + + // prepare tblobs and TVMArgs + std::vector tblobs = {inputs[0], outputs[0]}; + std::vector type_codes; + std::vector values; + + const size_t num_args = 3; // one input tensor, one scalar param, and one output + type_codes.resize(num_args); + values.resize(num_args); + + // input tensor setup + type_codes[0] = kArrayHandle; + values[0].v_handle = const_cast(&(tblobs[0].dltensor())); + + // scalar param + type_codes[1] = kDLFloat; + values[1].v_float64 = nnvm::get(attrs.parsed); + + // output tensor + type_codes[2] = kArrayHandle; + values[2].v_handle = const_cast(&(tblobs[1].dltensor())); + + tvm::runtime::TVMArgs tvm_args(&values[0], &type_codes[0], 3); + tvm::runtime::TVMOpModule::Get()->CallEx(func, ctx, tblobs, tvm_args); +#else + LOG(FATAL) << "Please add USE_TVM_OP=1 as a compile flag for compiling MXNet source code " + "to enable TVM-generated kernels for operator " << func; +#endif // MXNET_USE_TVM_OP + } +}; + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(name) \ + NNVM_REGISTER_OP(_npi_##name) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr_parser([](NodeAttrs* attrs) { \ + attrs->parsed = std::stod(attrs->dict["scalar"]); \ + }) \ + .set_attr("FListInputNames", \ + [](const NodeAttrs& attrs) { \ + return std::vector{"data"}; \ + }) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyBinaryScalarLogicOpType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs) { \ + return std::vector >{{0, 0}}; \ + }) \ + .set_attr("FCompute", TVMBinaryBroadcastScalarCompute{func_##name##_cpu}) \ + .set_attr("FGradient", MakeZeroGradNodes) \ + .add_argument("data", "NDArray-or-Symbol", "First input to the function") \ + .add_argument("scalar", "float", "scalar input") + +static constexpr char func_equal_scalar_cpu[] = "equal_scalar_cpu"; +static constexpr char func_equal_scalar_gpu[] = "equal_scalar_gpu"; +static constexpr char func_not_equal_scalar_cpu[] = "not_equal_scalar_cpu"; +static constexpr char func_not_equal_scalar_gpu[] = "not_equal_scalar_gpu"; +static constexpr char func_greater_scalar_cpu[] = "greater_scalar_cpu"; +static constexpr char func_greater_scalar_gpu[] = "greater_scalar_gpu"; +static constexpr char func_less_scalar_cpu[] = "less_scalar_cpu"; +static constexpr char func_less_scalar_gpu[] = "less_scalar_gpu"; +static constexpr char func_greater_equal_scalar_cpu[] = "greater_equal_scalar_cpu"; +static constexpr char func_greater_equal_scalar_gpu[] = "greater_equal_scalar_gpu"; +static constexpr char func_less_equal_scalar_cpu[] = "less_equal_scalar_cpu"; +static constexpr char func_less_equal_scalar_gpu[] = "less_equal_scalar_gpu"; + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(equal_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(not_equal_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(greater_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(less_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(greater_equal_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(less_equal_scalar); + +#if MXNET_USE_CUDA +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(name) \ + 
NNVM_REGISTER_OP(_npi_##name) \ + .set_attr("FCompute", TVMBinaryBroadcastScalarCompute{func_##name##_gpu}) + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(equal_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(not_equal_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(greater_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(less_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(greater_equal_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(less_equal_scalar); +#endif // MXNET_USE_CUDA MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_ldexp) .set_attr("FCompute", BinaryBroadcastCompute) diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index c690cb633306..49d0e23bdd19 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -59,6 +59,7 @@ IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(int8_t); IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(uint8_t); IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(int32_t); IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(int64_t); +IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(bool); /*! * \brief Init variable used to facilitate registering a tunable operator during @@ -84,6 +85,16 @@ struct static_init_var { __macro$(__VA_ARGS__, int32_t); \ __macro$(__VA_ARGS__, int64_t); +#define MSHADOW_MACRO_FOREACH_TYPE_WITH_BOOL(__macro$, ...) \ + __macro$(__VA_ARGS__, float); \ + __macro$(__VA_ARGS__, double); \ + __macro$(__VA_ARGS__, mshadow::half::half_t); \ + __macro$(__VA_ARGS__, uint8_t); \ + __macro$(__VA_ARGS__, int8_t); \ + __macro$(__VA_ARGS__, int32_t); \ + __macro$(__VA_ARGS__, int64_t); \ + __macro$(__VA_ARGS__, bool) + #define IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(__op$, __typ$) \ namespace mxnet_op { \ template<> std::vector mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::workload_ = \ @@ -183,9 +194,15 @@ struct static_init_var { #define IMPLEMENT_UNARY_WORKLOAD_FWD(__op$) \ MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_UNARY_WORKLOAD_FWD, __op$) +#define IMPLEMENT_UNARY_WORKLOAD_FWD_WITH_BOOL(__op$) \ + MSHADOW_MACRO_FOREACH_TYPE_WITH_BOOL(_IMPLEMENT_UNARY_WORKLOAD_FWD, __op$) + #define IMPLEMENT_BLANK_WORKLOAD_FWD(__op$) \ MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_BLANK_WORKLOAD_FWD, __op$) +#define IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(__op$) \ + MSHADOW_MACRO_FOREACH_TYPE_WITH_BOOL(_IMPLEMENT_BLANK_WORKLOAD_FWD, __op$) + #define IMPLEMENT_UNARY_WORKLOAD_BWD(__op$) \ MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_UNARY_WORKLOAD_BWD, __op$) @@ -206,7 +223,7 @@ struct static_init_var { * integer value */ OperatorTuneBase::duration_t OperatorTuneBase::omp_overhead_ns_ = 5000; -IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::identity); // NOLINT() +IMPLEMENT_UNARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::identity); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::identity_grad); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::negation); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal); // NOLINT() @@ -370,8 +387,8 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::lcm); // NOLINT() -IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>); // NOLINT() -IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<1>); // NOLINT() 
+IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<0>); // NOLINT() +IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<1>); // NOLINT() IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ldexp); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rldexp); // NOLINT() diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc index e3c2e0e898d9..cd433e00a770 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc +++ b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc @@ -30,7 +30,6 @@ namespace mxnet { namespace op { MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_equal) -.add_alias("_npi_equal") .describe(R"code(Returns the result of element-wise **equal to** (==) comparison operation with broadcasting. Example:: @@ -49,7 +48,6 @@ Example:: .set_attr("FGradient", MakeZeroGradNodes); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_not_equal) -.add_alias("_npi_not_equal") .describe(R"code(Returns the result of element-wise **not equal to** (!=) comparison operation with broadcasting. Example:: @@ -68,7 +66,6 @@ Example:: .set_attr("FGradient", MakeZeroGradNodes); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_greater) -.add_alias("_npi_greater") .describe(R"code(Returns the result of element-wise **greater than** (>) comparison operation with broadcasting. Example:: @@ -87,7 +84,6 @@ Example:: .set_attr("FGradient", MakeZeroGradNodes); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_greater_equal) -.add_alias("_npi_greater_equal") .describe(R"code(Returns the result of element-wise **greater than or equal to** (>=) comparison operation with broadcasting. Example:: @@ -106,7 +102,6 @@ Example:: .set_attr("FGradient", MakeZeroGradNodes); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_lesser) -.add_alias("_npi_less") .describe(R"code(Returns the result of element-wise **lesser than** (<) comparison operation with broadcasting. Example:: @@ -125,7 +120,6 @@ Example:: .set_attr("FGradient", MakeZeroGradNodes); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_lesser_equal) -.add_alias("_npi_less_equal") .describe(R"code(Returns the result of element-wise **lesser than or equal to** (<=) comparison operation with broadcasting. 
Example:: diff --git a/src/operator/tensor/elemwise_binary_scalar_op_logic.cc b/src/operator/tensor/elemwise_binary_scalar_op_logic.cc index 87ba394c99b2..17e76153ebb2 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_logic.cc +++ b/src/operator/tensor/elemwise_binary_scalar_op_logic.cc @@ -71,32 +71,26 @@ static bool BinaryScalarLogicStorageType(const nnvm::NodeAttrs& attrs, MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_equal_scalar, mshadow_op::eq) -.add_alias("_npi_equal_scalar") .set_attr("FGradient", MakeZeroGradNodes) .add_alias("_EqualScalar"); MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_not_equal_scalar, mshadow_op::ne) -.add_alias("_npi_not_equal_scalar") .set_attr("FGradient", MakeZeroGradNodes) .add_alias("_NotEqualScalar"); MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_greater_scalar, mshadow_op::gt) -.add_alias("_npi_greater_scalar") .set_attr("FGradient", MakeZeroGradNodes) .add_alias("_GreaterScalar"); MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_greater_equal_scalar, mshadow_op::ge) -.add_alias("_npi_greater_equal_scalar") .set_attr("FGradient", MakeZeroGradNodes) .add_alias("_GreaterEqualScalar"); MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_lesser_scalar, mshadow_op::lt) -.add_alias("_npi_less_scalar") .set_attr("FGradient", MakeZeroGradNodes) .add_alias("_LesserScalar"); MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_lesser_equal_scalar, mshadow_op::le) -.add_alias("_npi_less_equal_scalar") .set_attr("FGradient", MakeZeroGradNodes) .add_alias("_LesserEqualScalar"); diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index b50f89b8c7fa..22e7652a4019 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -341,7 +341,7 @@ class UnaryOp : public OpBase { break; case kAddTo: { Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, { mxnet_op::Kernel, xpu>::Launch( s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr()); }); diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index f3c405d7103c..8e8896e5c014 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -57,7 +57,7 @@ struct InitOpParam : public dmlc::Parameter { .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)." 
"Only used for imperative calls."); DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32) - MXNET_ADD_ALL_TYPES + MXNET_ADD_ALL_TYPES_WITH_BOOL .describe("Target data type."); } }; @@ -342,12 +342,12 @@ void Fill(mshadow::Stream *s, const TBlob& b, const OpReqType req, ValueTyp if (val == 0) { if (req != kAddTo) { if (b.dev_mask() == cpu::kDevMask && size < 50000) { - MSHADOW_TYPE_SWITCH(b.type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(b.type_flag_, DType, { memset(b.dptr_, 0, size * sizeof(DType)); }); } else { // Optimize common use-case of filling with ones - MSHADOW_TYPE_SWITCH(b.type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(b.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req, Req, { mxnet_op::Kernel, Req>, xpu>::Launch( s, b.Size(), b.dptr()); @@ -357,7 +357,7 @@ void Fill(mshadow::Stream *s, const TBlob& b, const OpReqType req, ValueTyp } } else if (is_integer && val == 1) { // Optimize common use-case of filling with ones - MSHADOW_TYPE_SWITCH(b.type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(b.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req, Req, { mxnet_op::Kernel, xpu>::Launch( s, b.Size(), b.dptr()); @@ -365,7 +365,7 @@ void Fill(mshadow::Stream *s, const TBlob& b, const OpReqType req, ValueTyp }); } else { // Generic fill kernel from variable - MSHADOW_TYPE_SWITCH(b.type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(b.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req, Req, { mxnet_op::Kernel, xpu>::Launch( s, b.Size(), b.dptr(), static_cast(val)); diff --git a/src/operator/tvmop/op_module.cc b/src/operator/tvmop/op_module.cc index d1d1c1d45d53..ea7f73069698 100644 --- a/src/operator/tvmop/op_module.cc +++ b/src/operator/tvmop/op_module.cc @@ -72,6 +72,9 @@ PackedFunc GetFunction(const std::shared_ptr &module, case mshadow::kInt64: func_name << "int64"; break; + case mshadow::kBool: + func_name << "bool"; + break; default: LOG(FATAL) << "Unknown dtype " << arg.type_flag_; } @@ -82,7 +85,7 @@ PackedFunc GetFunction(const std::shared_ptr &module, void TVMOpModule::Call(const std::string &func_name, const mxnet::OpContext& ctx, - const std::vector &args) { + const std::vector &args) const { std::vector type_codes; std::vector values; @@ -112,6 +115,28 @@ void TVMOpModule::Call(const std::string &func_name, #endif } +void TVMOpModule::CallEx(const std::string &func_name, + const mxnet::OpContext& ctx, + const std::vector& tblobs, + TVMArgs tvm_args) const { + TVMRetValue rv; + +#if MXNET_USE_CUDA + int dev_type = (ctx.run_ctx.ctx.dev_type == mxnet::Context::DeviceType::kGPU) ? kDLGPU : kDLCPU; + int dev_id = ctx.run_ctx.ctx.dev_id; + if (dev_type == kDLGPU) { + void *stream = static_cast(ctx.run_ctx.get_stream()->stream_); + TVMSetStream(dev_type, dev_id, stream); + } +#endif + GetFunction(module_ptr_, func_name, tblobs).CallPacked(tvm_args, &rv); +#if MXNET_USE_CUDA + if (dev_type == kDLGPU) { + TVMSetStream(dev_type, dev_id, nullptr); + } +#endif +} + } // namespace runtime } // namespace tvm #endif // MXNET_USE_TVM_OP diff --git a/src/operator/tvmop/op_module.h b/src/operator/tvmop/op_module.h index 04e97ef51f4e..d28dd1b5b0c2 100644 --- a/src/operator/tvmop/op_module.h +++ b/src/operator/tvmop/op_module.h @@ -36,6 +36,7 @@ namespace tvm { namespace runtime { +class TVMArgs; class Module; class TVMOpModule { public: @@ -44,7 +45,22 @@ class TVMOpModule { void Call(const std::string& func_name, const mxnet::OpContext& ctx, - const std::vector& args); + const std::vector& args) const; + + /*! 
+ * \brief Launch operator kernels which have been pre-compiled into a lib file + * by TVM compiler. + * \param func_name Function name that corresponds to the operator kernel + * \param ctx Operator context that includes device and stream information. + * \param tblobs Tensor blobs whose dtype and shape information are extracted + * to construct the function name. Each configuration of dtype and shape has + * a unique kernel. + * \param tvm_args Arguments to be passed to kernel function. + */ + void CallEx(const std::string &func_name, + const mxnet::OpContext& ctx, + const std::vector& tblobs, + TVMArgs tvm_args) const; static TVMOpModule *Get() { static TVMOpModule inst; diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py index 52ceb718fb15..f24eb6a325bf 100644 --- a/tests/python/unittest/test_numpy_gluon.py +++ b/tests/python/unittest/test_numpy_gluon.py @@ -122,18 +122,18 @@ def test_np_loss_ndarray(): weighting = np.array([0.5, 1, 0.5, 1]) loss = gluon.loss.L1Loss() - assert np.sum(loss(output, label)) == 6. + assert float(np.sum(loss(output, label))) == 6. loss = gluon.loss.L1Loss(weight=0.5) - assert np.sum(loss(output, label)) == 3. + assert float(np.sum(loss(output, label))) == 3. loss = gluon.loss.L1Loss() - assert np.sum(loss(output, label, weighting)) == 5. + assert float(np.sum(loss(output, label, weighting))) == 5. loss = gluon.loss.L2Loss() - assert np.sum(loss(output, label)) == 7. + assert float(np.sum(loss(output, label))) == 7. loss = gluon.loss.L2Loss(weight=0.25) - assert np.sum(loss(output, label)) == 1.75 + assert float(np.sum(loss(output, label))) == 1.75 loss = gluon.loss.L2Loss() - assert np.sum(loss(output, label, weighting)) == 6 + assert float(np.sum(loss(output, label, weighting))) == 6 output = np.array([[0, 2], [1, 4]]) label = np.array([0, 1]) diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 3a5e72b53d58..20b964c96a30 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -19,13 +19,14 @@ from __future__ import absolute_import from __future__ import division import os +import unittest import numpy as _np import mxnet as mx from mxnet import np, npx, autograd from mxnet.gluon import HybridBlock -from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, retry, assert_exception, use_np +from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, retry, use_np from common import with_seed, TemporaryDirectory -from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf +from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, has_tvm_ops, assert_exception from mxnet.ndarray.ndarray import py_slice from mxnet.base import integer_types import scipy.stats as ss @@ -34,8 +35,23 @@ @with_seed() @use_np def test_np_empty(): - dtypes = [np.int8, np.int32, np.float16, np.float32, np.float64, None] - expected_dtypes = [np.int8, np.int32, np.float16, np.float32, np.float64, np.float32] + # (input dtype, expected output dtype) + dtype_pairs = [ + (np.int8, np.int8), + (np.int32, np.int32), + (np.float16, np.float16), + (np.float32, np.float32), + (np.float64, np.float64), + (np.bool_, np.bool_), + (np.bool, np.bool_), + ('int8', np.int8), + ('int32', np.int32), + ('float16', np.float16), + ('float32', np.float32), + ('float64', np.float64), + ('bool', np.bool_), + (None, np.float32), + ] orders = ['C', 'F', 'A'] shapes = [ (), @@ 
-49,7 +65,7 @@ def test_np_empty(): (1, 1, 1, 1), ] ctxes = [npx.current_context(), None] - for dtype, expected_dtype in zip(dtypes, expected_dtypes): + for dtype, expected_dtype in dtype_pairs: for shape in shapes: for order in orders: for ctx in ctxes: @@ -65,7 +81,8 @@ def test_np_empty(): @with_seed() @use_np def test_np_array_creation(): - dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, None] + dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, _np.bool, _np.bool_, + 'int8', 'int32', 'float16', 'float32', 'float64', 'bool', None] objects = [ [], (), @@ -76,7 +93,7 @@ for dtype in dtypes: for src in objects: mx_arr = np.array(src, dtype=dtype) - assert mx_arr.context == mx.current_context() + assert mx_arr.ctx == mx.current_context() if isinstance(src, mx.nd.NDArray): np_arr = _np.array(src.asnumpy(), dtype=dtype if dtype is not None else _np.float32) else: @@ -110,6 +127,8 @@ def check_zero_array_creation(shape, dtype): if dtype is None: assert mx_out.dtype == _np.float32 assert np_out.dtype == _np.float64 + else: + assert mx_out.dtype == np_out.dtype shapes = [(0,), (2, 0, 2), (0, 0, 0, 0), ()] shapes += [rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)] @@ -132,6 +151,10 @@ def check_zero_array_creation(shape, dtype): y = test_zeros_output_type(x) assert type(y[1]) == np.ndarray + for shape in shapes: + for dtype in [_np.bool, bool, _np.bool_, 'bool']: + check_zero_array_creation(shape, dtype) + @with_seed() @use_np @@ -158,6 +181,8 @@ def check_ones_array_creation(shape, dtype): if dtype is None: assert mx_out.dtype == _np.float32 assert np_out.dtype == _np.float64 + else: + assert mx_out.dtype == np_out.dtype shapes = [(0,), (2, 0, 2), (0, 0, 0, 0), ()] shapes += [rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)] @@ -180,6 +205,10 @@ def check_ones_array_creation(shape, dtype): y = test_ones_output_type(x) assert type(y[1]) == np.ndarray + for shape in shapes: + for dtype in [_np.bool, bool, _np.bool_, 'bool']: + check_ones_array_creation(shape, dtype) + @with_seed() @use_np @@ -235,13 +264,18 @@ def test_np_ndarray_binary_element_wise_ops(): '/': _np.divide, 'mod': _np.mod, 'pow': _np.power, - '==': _np.equal, - '>': _np.greater, - '>=': _np.greater_equal, - '<': _np.less, - '<=': _np.less_equal } + if has_tvm_ops(): + np_op_map.update({ + '==': _np.equal, + '!=': _np.not_equal, + '>': _np.greater, + '>=': _np.greater_equal, + '<': _np.less, + '<=': _np.less_equal + }) + def get_np_ret(x1, x2, op): return np_op_map[op](x1, x2) @@ -309,10 +343,16 @@ def hybrid_forward(self, F, x, *args): return x == self._scalar if not self._reverse else self._scalar == x else: return x == args[0] + elif self._op == '!=': + if self._scalar is not None: + return x != self._scalar if not self._reverse else self._scalar != x + else: + return x != args[0] else: print(self._op) assert False + logic_ops = ['==', '!=', '>', '<', '>=', '<='] @use_np def check_binary_op_result(shape1, shape2, op, dtype=None): if shape1 is None: @@ -348,6 +388,8 @@ def check_binary_op_result(shape1, shape2, op, dtype=None): mx_out = get_mx_ret_np(mx_input1.as_np_ndarray(), mx_input2.as_np_ndarray()) assert type(mx_out) == np.ndarray assert np_out.shape == mx_out.shape + if op in logic_ops: + assert np_out.dtype == mx_out.dtype assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5) else: get_mx_ret = TestBinaryElementWiseOp(op, scalar=scalar, reverse=reverse)
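
A condensed version of what these assertions check: when TVM kernels are available, every logic op must agree with NumPy on both the values and the new `bool` output dtype (sketch, TVM-enabled build assumed):

```python
import numpy as _np
from mxnet import np

x, y = np.random.uniform(size=(3, 4)), np.random.uniform(size=(3, 4))
for mx_op, np_op in [(np.equal, _np.equal), (np.not_equal, _np.not_equal),
                     (np.greater, _np.greater), (np.less_equal, _np.less_equal)]:
    mx_out, np_out = mx_op(x, y), np_op(x.asnumpy(), y.asnumpy())
    assert mx_out.dtype == np_out.dtype == _np.bool_   # bool, not float32
    assert (mx_out.asnumpy() == np_out).all()
```
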
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index 3a5e72b53d58..20b964c96a30 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -19,13 +19,14 @@
 from __future__ import absolute_import
 from __future__ import division
 import os
+import unittest
 import numpy as _np
 import mxnet as mx
 from mxnet import np, npx, autograd
 from mxnet.gluon import HybridBlock
-from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, retry, assert_exception, use_np
+from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, retry, use_np
 from common import with_seed, TemporaryDirectory
-from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf
+from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, has_tvm_ops, assert_exception
 from mxnet.ndarray.ndarray import py_slice
 from mxnet.base import integer_types
 import scipy.stats as ss
@@ -34,8 +35,23 @@
 @with_seed()
 @use_np
 def test_np_empty():
-    dtypes = [np.int8, np.int32, np.float16, np.float32, np.float64, None]
-    expected_dtypes = [np.int8, np.int32, np.float16, np.float32, np.float64, np.float32]
+    # (input dtype, expected output dtype)
+    dtype_pairs = [
+        (np.int8, np.int8),
+        (np.int32, np.int32),
+        (np.float16, np.float16),
+        (np.float32, np.float32),
+        (np.float64, np.float64),
+        (np.bool_, np.bool_),
+        (np.bool, np.bool_),
+        ('int8', np.int8),
+        ('int32', np.int32),
+        ('float16', np.float16),
+        ('float32', np.float32),
+        ('float64', np.float64),
+        ('bool', np.bool_),
+        (None, np.float32),
+    ]
     orders = ['C', 'F', 'A']
     shapes = [
         (),
@@ -49,7 +65,7 @@ def test_np_empty():
         (1, 1, 1, 1),
     ]
     ctxes = [npx.current_context(), None]
-    for dtype, expected_dtype in zip(dtypes, expected_dtypes):
+    for dtype, expected_dtype in dtype_pairs:
         for shape in shapes:
             for order in orders:
                 for ctx in ctxes:
@@ -65,7 +81,8 @@ def test_np_empty():
 @with_seed()
 @use_np
 def test_np_array_creation():
-    dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, None]
+    dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, _np.bool, _np.bool_,
+              'int8', 'int32', 'float16', 'float32', 'float64', 'bool', None]
     objects = [
         [],
         (),
@@ -76,7 +93,7 @@
     for dtype in dtypes:
         for src in objects:
             mx_arr = np.array(src, dtype=dtype)
-            assert mx_arr.context == mx.current_context()
+            assert mx_arr.ctx == mx.current_context()
             if isinstance(src, mx.nd.NDArray):
                 np_arr = _np.array(src.asnumpy(), dtype=dtype if dtype is not None else _np.float32)
             else:
@@ -110,6 +127,8 @@
         if dtype is None:
             assert mx_out.dtype == _np.float32
             assert np_out.dtype == _np.float64
+        else:
+            assert mx_out.dtype == np_out.dtype
 
     shapes = [(0,), (2, 0, 2), (0, 0, 0, 0), ()]
     shapes += [rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)]
@@ -132,6 +151,10 @@ def check_zero_array_creation(shape, dtype):
         y = test_zeros_output_type(x)
         assert type(y[1]) == np.ndarray
 
+    for shape in shapes:
+        for dtype in [_np.bool, bool, _np.bool_, 'bool']:
+            check_zero_array_creation(shape, dtype)
+
 
 @with_seed()
 @use_np
@@ -158,6 +181,8 @@ def check_ones_array_creation(shape, dtype):
         if dtype is None:
             assert mx_out.dtype == _np.float32
             assert np_out.dtype == _np.float64
+        else:
+            assert mx_out.dtype == np_out.dtype
 
     shapes = [(0,), (2, 0, 2), (0, 0, 0, 0), ()]
     shapes += [rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)]
@@ -180,6 +205,10 @@ def check_ones_array_creation(shape, dtype):
         y = test_ones_output_type(x)
         assert type(y[1]) == np.ndarray
 
+    for shape in shapes:
+        for dtype in [_np.bool, bool, _np.bool_, 'bool']:
+            check_ones_array_creation(shape, dtype)
+
 
 @with_seed()
 @use_np
@@ -235,13 +264,18 @@ def test_np_ndarray_binary_element_wise_ops():
         '/': _np.divide,
         'mod': _np.mod,
         'pow': _np.power,
-        '==': _np.equal,
-        '>': _np.greater,
-        '>=': _np.greater_equal,
-        '<': _np.less,
-        '<=': _np.less_equal
     }
+
+    if has_tvm_ops():
+        np_op_map.update({
+            '==': _np.equal,
+            '!=': _np.not_equal,
+            '>': _np.greater,
+            '>=': _np.greater_equal,
+            '<': _np.less,
+            '<=': _np.less_equal
+        })
+
     def get_np_ret(x1, x2, op):
         return np_op_map[op](x1, x2)
 
@@ -309,10 +343,16 @@ def hybrid_forward(self, F, x, *args):
                     return x == self._scalar if not self._reverse else self._scalar == x
                 else:
                     return x == args[0]
+            elif self._op == '!=':
+                if self._scalar is not None:
+                    return x != self._scalar if not self._reverse else self._scalar != x
+                else:
+                    return x != args[0]
             else:
                 print(self._op)
                 assert False
 
+    logic_ops = ['==', '!=', '>', '<', '>=', '<=']
     @use_np
     def check_binary_op_result(shape1, shape2, op, dtype=None):
         if shape1 is None:
@@ -348,6 +388,8 @@ def check_binary_op_result(shape1, shape2, op, dtype=None):
             mx_out = get_mx_ret_np(mx_input1.as_np_ndarray(), mx_input2.as_np_ndarray())
             assert type(mx_out) == np.ndarray
             assert np_out.shape == mx_out.shape
+            if op in logic_ops:
+                assert np_out.dtype == mx_out.dtype
             assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5)
         else:
             get_mx_ret = TestBinaryElementWiseOp(op, scalar=scalar, reverse=reverse)
@@ -360,6 +402,8 @@ def check_binary_op_result(shape1, shape2, op, dtype=None):
             mx_out = get_mx_ret(mx_input1.as_np_ndarray())
             assert type(mx_out) == np.ndarray
             assert np_out.shape == mx_out.shape
+            if op in logic_ops:
+                assert np_out.dtype == mx_out.dtype
             assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5)
 
     dtypes = [_np.float32, _np.float64, None]
@@ -933,6 +977,87 @@ def test_np_multinomial():
             mx.test_utils.assert_almost_equal(freq[i, :], pvals, rtol=0.20, atol=1e-1)
 
 
+@with_seed()
+@unittest.skipUnless(has_tvm_ops(), "Comparison ops are implemented using TVM")
+@use_np
+def test_np_ndarray_boolean_indexing():
+    def test_single_bool_index():
+        # adapted from numpy's test_indexing.py
+        # Single boolean index
+        a = np.array([[1, 2, 3],
+                      [4, 5, 6],
+                      [7, 8, 9]], dtype=np.int32)
+        assert same(a[np.array(True, dtype=np.bool_)].asnumpy(), a[None].asnumpy())
+        assert same(a[np.array(False, dtype=np.bool_)].asnumpy(), a[None][0:0].asnumpy())
+
+    def test_boolean_catch_exception():
+        # adapted from numpy's test_indexing.py
+        arr = np.ones((5, 4, 3))
+
+        index = np.array([True], dtype=np.bool_)
+        assert_exception(arr.__getitem__, IndexError, index)
+
+        index = np.array([False] * 6, dtype=np.bool_)
+        assert_exception(arr.__getitem__, IndexError, index)
+
+        index = np.zeros((4, 4), dtype=bool)
+        assert_exception(arr.__getitem__, IndexError, index)
+
+        assert_exception(arr.__getitem__, TypeError, (slice(None), index))
+
+    def test_boolean_indexing_onedim():
+        # adapted from numpy's test_indexing.py
+        # Indexing a 2-dimensional array with
+        # a boolean array of length one
+        a = np.array([[0., 0., 0.]])
+        b = np.array([True], dtype=bool)
+        assert same(a[b].asnumpy(), a.asnumpy())
+
+    def test_boolean_indexing_twodim():
+        # adapted from numpy's test_indexing.py
+        # Indexing a 2-dimensional array with
+        # a 2-dimensional boolean array
+        a = np.array([[1, 2, 3],
+                      [4, 5, 6],
+                      [7, 8, 9]], dtype=np.int32)
+        b = np.array([[ True, False,  True],
+                      [False,  True, False],
+                      [ True, False,  True]], dtype=np.bool_)
+        assert same(a[b].asnumpy(), _np.array([1, 3, 5, 7, 9], dtype=a.dtype))
+        assert same(a[b[1]].asnumpy(), _np.array([[4, 5, 6]], dtype=a.dtype))
+        assert same(a[b[0]].asnumpy(), a[b[2]].asnumpy())
+
+    def test_boolean_indexing_list():
+        # adapted from numpy's test_indexing.py
+        a = np.array([1, 2, 3], dtype=np.int32)
+        b = [True, False, True]
+        # Two variants of the test because the first takes a fast path
+        assert same(a[b].asnumpy(), _np.array([1, 3], dtype=a.dtype))
+        assert same(a[None, b].asnumpy(), _np.array([[1, 3]], dtype=a.dtype))
+
+    def test_boolean_indexing_autograd():
+        a = np.random.uniform(size=(3, 4, 5))
+        a.attach_grad()
+        with mx.autograd.record():
+            out_mx = a[a < 0.5]
+        out_mx.backward()
+
+        a_np = a.asnumpy()
+        out_np = a_np[a_np < 0.5]
+        assert_almost_equal(out_mx.asnumpy(), out_np, rtol=1e-4, atol=1e-5, use_broadcast=False)
+
+        a_grad_np = _np.zeros(a.shape, dtype=a.dtype)
+        a_grad_np[a_np < 0.5] = 1
+        assert_almost_equal(a.grad.asnumpy(), a_grad_np, rtol=1e-4, atol=1e-5, use_broadcast=False)
+
+    test_single_bool_index()
+    test_boolean_catch_exception()
+    test_boolean_indexing_onedim()
+    test_boolean_indexing_twodim()
+    test_boolean_indexing_list()
+    test_boolean_indexing_autograd()
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
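
For reference, the semantics test_np_ndarray_boolean_indexing asserts are those of plain numpy, which mxnet.numpy is expected to match when TVM-backed comparison ops are available. A short reference sketch in plain numpy:

    import numpy as _np

    a = _np.array([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]], dtype=_np.int32)
    mask = a > 4                         # element-wise comparison -> boolean mask
    print(a[mask])                       # [5 6 7 8 9]; masked selection is always 1-D
    row_mask = _np.array([False, True, False])
    print(a[row_mask])                   # [[4 5 6]]; a length-3 mask selects whole rows
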
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index e6b3d41d37b0..45bcc517a641 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -29,15 +29,11 @@
 from common import assertRaises, with_seed
 import random
 import scipy.stats as ss
-from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, collapse_sum_like
-from mxnet.runtime import Features
 from mxnet.numpy_op_signature import _get_builtin_op
+from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, has_tvm_ops
 import platform
 
-_features = Features()
-
-
 @with_seed()
 @use_np
 def test_np_tensordot():
@@ -585,13 +581,14 @@ def is_int(dtype):
     in_data_dim = random.choice([2, 3, 4])
     shape = rand_shape_nd(in_data_dim, dim=3)
     acc_type = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64',
-                'int8': 'int32', 'int32': 'int64', 'int64': 'int64'}
+                'int8': 'int32', 'int32': 'int64', 'int64': 'int64', 'bool': 'int64'}
     for hybridize in [False, True]:
         for keepdims in [True, False]:
             for axis in ([i for i in range(in_data_dim)] + [(), None]):
-                for itype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']:
+                for itype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64', 'bool']:
                     for dtype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']:
-                        if is_int(dtype) and not is_int(itype):
+                        if (is_int(dtype) and not is_int(itype))\
+                                or (itype == 'bool' and dtype not in ('float32', 'float64', 'int32', 'int64')):
                             continue
                         # test gluon
                         test_sum = TestSum(axis=axis, dtype=dtype, keepdims=keepdims)
@@ -599,13 +596,23 @@ def is_int(dtype):
                             test_sum.hybridize()
                         if is_int(itype):
                             x = _np.random.randint(-128, 128, shape, dtype=itype)
-                            x = mx.nd.array(x)
+                            x = np.array(x)
+                        elif itype == 'bool':
+                            x = _np.random.randint(0, 2, shape) < 1
+                            x = np.array(x, dtype='bool')
                         else:
-                            x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
-                            x = x.as_np_ndarray()
-                        x.attach_grad()
+                            x = np.random.uniform(-1.0, 1.0, size=shape, dtype=itype)
                         expected_ret = _np.sum(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
                         expected_ret = expected_ret.astype(dtype)
+                        if itype == 'bool':  # special handling of boolean ndarray
+                            if has_tvm_ops():
+                                y = test_sum(x)
+                                assert y.dtype == expected_ret.dtype
+                                assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-4, atol=1e-5,
+                                                    use_broadcast=False)
+                            continue
+
+                        x.attach_grad()
                         with mx.autograd.record():
                             y = test_sum(x)
                         assert y.shape == expected_ret.shape
@@ -1464,7 +1471,7 @@ def hybrid_forward(self, F, a, *args, **kwargs):
         'arccosh' : (lambda x: 1./(x**2 - 1.)**(1./2.), 2.0, 5.0),
         'arctanh' : (lambda x: -1./(x**2 - 1.), -0.99, 0.99)
     }
-    if _features.is_enabled("TVM_OP"):
+    if has_tvm_ops():
         funcs['rad2deg'] = (lambda x: 180. / _np.pi * _np.ones(x.shape), -1.0, 1.0)
         funcs['deg2rad'] = (lambda x: _np.pi / 180. * _np.ones(x.shape), -1.0, 1.0)
     ndim = random.choice([2, 3, 4])
@@ -2232,7 +2239,7 @@ def test_np_indices():
         (2, 3, 4, 5, 6, 7)
     ]
     if platform.system() == 'Windows':
-        shapes = shapes[1:]  #beacuse in numpy windows version, indces not support dimensions is empty tuple.
+        shapes = shapes[1:]  # because numpy's indices on Windows does not support an empty dimensions tuple.
     for dtype in dtypes:
         for shape in shapes:
             np_out = _np.indices(dimensions=shape, dtype=dtype)
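
The new 'bool' entry in acc_type above records the accumulation dtype expected when summing boolean input; expected_ret is computed with plain numpy, whose behavior is sketched below (True counts as 1, accumulated in an integer type unless an explicit dtype is requested):

    import numpy as _np

    x = _np.array([True, False, True, True])
    print(_np.sum(x))                   # 3 -- boolean input accumulates as integers
    print(_np.sum(x, dtype='int64'))    # 3 -- matches acc_type['bool'] == 'int64'
    print(_np.sum(x, dtype='float32'))  # 3.0 -- explicit floating-point output
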