diff --git a/3rdparty/mshadow/mshadow/base.h b/3rdparty/mshadow/mshadow/base.h index 1076f122c3ae..e0e9602c00db 100755 --- a/3rdparty/mshadow/mshadow/base.h +++ b/3rdparty/mshadow/mshadow/base.h @@ -311,6 +311,7 @@ enum TypeFlag { kInt32 = 4, kInt8 = 5, kInt64 = 6, + kBool = 7, }; template<typename DType> @@ -411,6 +412,11 @@ struct DataType<int64_t> { static const int kFlag = kInt64; static const int kLanes = 1; }; +template<> +struct DataType<bool> { + static const int kFlag = kBool; + static const int kLanes = 1; +}; /*! \brief type enum value for default real type */ const int default_type_flag = DataType<default_real_t>::kFlag; @@ -1138,10 +1144,64 @@ struct minimum { LOG(FATAL) << "Unknown type enum " << type; \ } +#define MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType, ...) \ + switch (type) { \ + case mshadow::kFloat32: \ + { \ + typedef float DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat64: \ + { \ + typedef double DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat16: \ + { \ + typedef mshadow::half::half_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kUint8: \ + { \ + typedef uint8_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt8: \ + { \ + typedef int8_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt32: \ + { \ + typedef int32_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt64: \ + { \ + typedef int64_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kBool: \ + { \ + typedef bool DType; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + LOG(FATAL) << "Unknown type enum " << type; \ + } + /*! \brief get data type size from type enum */ inline size_t mshadow_sizeof(int type) { int size = 0; - MSHADOW_TYPE_SWITCH(type, DType, size = sizeof(DType);); + MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType, size = sizeof(DType);); return size; } diff --git a/CMakeLists.txt b/CMakeLists.txt index f441e9b0bd3b..1e3bbceb362f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -756,11 +756,15 @@ if(USE_TVM_OP) endif() endif() + set(TVM_OP_COMPILE_OPTIONS "-o${CMAKE_CURRENT_BINARY_DIR}/libtvmop.so") + if(CUDA_ARCH_BIN) + set(TVM_OP_COMPILE_OPTIONS "${TVM_OP_COMPILE_OPTIONS} --cuda-arch ${CUDA_ARCH_BIN}") + endif() add_custom_command(TARGET mxnet POST_BUILD COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH="${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/python:${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/topi/python:${CMAKE_CURRENT_SOURCE_DIR}/contrib" LD_LIBRARY_PATH=${CMAKE_CURRENT_BINARY_DIR}:${CMAKE_CURRENT_BINARY_DIR}/3rdparty/tvm:$ENV{LD_LIBRARY_PATH} - ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/contrib/tvmop/compile.py -o${CMAKE_CURRENT_BINARY_DIR}/libtvmop.so + ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/contrib/tvmop/compile.py ${TVM_OP_COMPILE_OPTIONS} ) endif() diff --git a/Makefile b/Makefile index b3b188a2cf3b..0a1e355ee5e8 100644 --- a/Makefile +++ b/Makefile @@ -630,11 +630,15 @@ lib/libtvm_runtime.so: ls $(ROOTDIR)/lib; \ cd $(ROOTDIR) +TVM_OP_COMPILE_OPTIONS = -o $(ROOTDIR)/lib/libtvmop.so +ifneq ($(CUDA_ARCH),) + TVM_OP_COMPILE_OPTIONS += --cuda-arch "$(CUDA_ARCH)" +endif lib/libtvmop.so: lib/libtvm_runtime.so $(wildcard contrib/tvmop/*/*.py contrib/tvmop/*.py) echo "Compile TVM operators" PYTHONPATH=$(TVM_PATH)/python:$(TVM_PATH)/topi/python:$(ROOTDIR)/contrib \ LD_LIBRARY_PATH=$(ROOTDIR)/lib \ - python3 $(ROOTDIR)/contrib/tvmop/compile.py -o $(ROOTDIR)/lib/libtvmop.so + python3 $(ROOTDIR)/contrib/tvmop/compile.py $(TVM_OP_COMPILE_OPTIONS) NNVM_INC = $(wildcard $(NNVM_PATH)/include/*/*.h) NNVM_SRC = $(wildcard
$(NNVM_PATH)/src/*/*/*.cc $(NNVM_PATH)/src/*/*.cc $(NNVM_PATH)/src/*.cc) diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index b2a50f30af4e..7d91d4e5d121 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -228,7 +228,7 @@ build_ubuntu_gpu_mkldnn_release() { # $1 -> mxnet_variant: the mxnet variant to build, e.g. cpu, cu100, cu92mkl, etc. build_dynamic_libmxnet() { set -ex - + local mxnet_variant=${1:?"This function requires a mxnet variant as the first argument"} # relevant licenses will be placed in the licenses directory @@ -948,7 +948,7 @@ cd_unittest_ubuntu() { fi $nose_cmd $NOSE_TIMER_ARGUMENTS --verbose tests/python/unittest - $nose_cmd $NOSE_TIMER_ARGUMENTS --verbose tests/python/quantization + $nose_cmd $NOSE_TIMER_ARGUMENTS --verbose tests/python/quantization # https://github.com/apache/incubator-mxnet/issues/11801 # if [[ ${mxnet_variant} = "cpu" ]] || [[ ${mxnet_variant} = "mkl" ]]; then diff --git a/contrib/tvmop/__init__.py b/contrib/tvmop/__init__.py index 1234ee7d31f1..db41574ee867 100644 --- a/contrib/tvmop/__init__.py +++ b/contrib/tvmop/__init__.py @@ -21,3 +21,4 @@ from .utils import assign_by_req, reduce_axes from . import basic +from . import core diff --git a/contrib/tvmop/compile.py b/contrib/tvmop/compile.py index e6af0a276560..3c0efdd6b806 100644 --- a/contrib/tvmop/compile.py +++ b/contrib/tvmop/compile.py @@ -21,7 +21,13 @@ import os import argparse +import re +import logging from tvmop.opdef import __OP_DEF__ +from tvm.autotvm.measure.measure_methods import set_cuda_target_arch + +logging.basicConfig(level=logging.INFO) + def get_target(device): if device == "cpu": @@ -31,12 +37,39 @@ def get_target(device): assert False, "Unknown device " + device +def get_cuda_arch(arch): + if arch is None: + return None + + if not isinstance(arch, str): + raise TypeError('Expecting parameter arch as a str, while got a {}'.format(str(type(arch)))) + + if len(arch) == 0: + return None + + # the arch string contains '-arch=sm_xx' + flags = arch.split() + for flag in flags: + if flag.startswith('-arch='): + return flag[len('-arch='):] + + # find the highest compute capability + comp_caps = re.findall(r'\d+', arch) + if len(comp_caps) == 0: + return None + + comp_caps = [int(c) for c in comp_caps] + return 'sm_' + str(max(comp_caps)) + + if __name__ == "__main__": import sys sys.path.append(os.path.dirname(sys.path[0])) parser = argparse.ArgumentParser(description="Generate tvm operators") parser.add_argument("-o", action="store", required=True, dest="target_path", help="Target path which stores compiled library") + parser.add_argument('--cuda-arch', type=str, default=None, dest='cuda_arch', + help='The cuda arch for compiling kernels for') arguments = parser.parse_args() func_list_llvm = [] @@ -52,8 +85,14 @@ def get_target(device): binds=operator_def.get_binds(args)) func_list.append(func_lower) - lowered_funcs = {get_target("cpu") : func_list_llvm} + lowered_funcs = {get_target("cpu"): func_list_llvm} if len(func_list_cuda) > 0: lowered_funcs[get_target("cuda")] = func_list_cuda + cuda_arch = get_cuda_arch(arguments.cuda_arch) + if cuda_arch is None: + logging.info('No cuda arch specified. 
TVM will try to detect it from the build platform.') + else: + logging.info('Cuda arch {} set for compiling TVM operator kernels.'.format(cuda_arch)) + set_cuda_target_arch(cuda_arch) func_binary = tvm.build(lowered_funcs, name="tvmop") func_binary.export_library(arguments.target_path) diff --git a/contrib/tvmop/core/__init__.py b/contrib/tvmop/core/__init__.py new file mode 100644 index 000000000000..841d4ad9db27 --- /dev/null +++ b/contrib/tvmop/core/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from . import umath, fromnumeric diff --git a/contrib/tvmop/core/fromnumeric.py b/contrib/tvmop/core/fromnumeric.py new file mode 100644 index 000000000000..e6c4c2be0814 --- /dev/null +++ b/contrib/tvmop/core/fromnumeric.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import tvm +from .. 
import defop +from ..utils import reduce_axes, assign_by_req + + +def _compute_sum(itype, otype, ndim, reduce1st_dim, req): + axes = ([reduce1st_dim, 1 - reduce1st_dim] * ndim)[:ndim] + a = tvm.placeholder([tvm.var() for _ in range(ndim)], name='a', dtype=itype) + reduce_output = reduce_axes(a, axes, tvm.sum, otype) + output_placeholder, final_output = assign_by_req(reduce_output, req) + s = tvm.create_schedule(final_output.op) + return s, a, output_placeholder, final_output, [reduce_output, final_output] + + +@defop(name='sum_cpu', target='cpu', itype=['bool'], + otype=['float32', 'float64', 'int32', 'int64'], + ndim=[5], req=['kWriteTo', 'kAddTo'], reduce1st_dim=[0, 1], + attrs=["reduce1st_dim", "req"]) +def _sum_cpu(itype, otype, ndim, reduce1st_dim, req): + s, a, output_placeholder, final_output, tensor_list = _compute_sum( + itype, otype, ndim, reduce1st_dim, req) + for t in tensor_list: + axes = [axis for axis in t.op.axis] + fused = s[t].fuse(*axes) + s[t].parallel(fused) + return s, [a, output_placeholder, final_output] + + +@defop(name='sum_gpu', target='gpu', itype=['bool'], + otype=['float32', 'float64', 'int32', 'int64'], + ndim=[5], req=['kWriteTo', 'kAddTo'], reduce1st_dim=[0, 1], + attrs=["reduce1st_dim", "req"]) +def _sum_gpu(itype, otype, ndim, reduce1st_dim, req): + s, a, output_placeholder, final_output, tensor_list = _compute_sum( + itype, otype, ndim, reduce1st_dim, req) + num_threads = 64 + for t in tensor_list: + block_x = tvm.thread_axis("blockIdx.x") + thread_x = tvm.thread_axis("threadIdx.x") + axes = [axis for axis in t.op.axis] + fused = s[t].fuse(*axes) + bx, tx = s[t].split(fused, factor=num_threads) + s[t].bind(bx, block_x) + s[t].bind(tx, thread_x) + return s, [a, output_placeholder, final_output] diff --git a/contrib/tvmop/core/umath.py b/contrib/tvmop/core/umath.py new file mode 100644 index 000000000000..ad099299aae5 --- /dev/null +++ b/contrib/tvmop/core/umath.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from .. 
import defop, AllTypes + +_bin_logic_op_map = { + 'equal': lambda a, b, *idx: a[idx] == b[idx], + 'not_equal': lambda a, b, *idx: a[idx] != b[idx], + 'greater': lambda a, b, *idx: a[idx] > b[idx], + 'less': lambda a, b, *idx: a[idx] < b[idx], + 'greater_equal': lambda a, b, *idx: a[idx] >= b[idx], + 'less_equal': lambda a, b, *idx: a[idx] <= b[idx], +} + + +def _compute_binary_logic(op, dtype, ndim): + a = tvm.placeholder([tvm.var() for _ in range(ndim)], dtype=dtype, name='a') + b = tvm.placeholder([tvm.var() for _ in range(ndim)], dtype=dtype, name='b') + c = tvm.compute([tvm.var() for _ in range(ndim)], + lambda *idx: _bin_logic_op_map[op](a, b, *idx), name='c') + s = tvm.create_schedule(c.op) + return s, a, b, c + + +_bin_logic_cpu_attrs = { + 'compute_func': _compute_binary_logic, + 'target': 'cpu', + 'auto_broadcast': True, + 'itype': AllTypes + ['bool'], + 'ndim': list(range(6)) +} + +_bin_logic_gpu_attrs = { + 'compute_func': _compute_binary_logic, + 'target': 'gpu', + 'auto_broadcast': True, + 'itype': AllTypes + ['bool'], + 'ndim': list(range(6)) +} + + +def _binary_logic_cpu(compute_func, op, itype, ndim): + s, a, b, c = compute_func(op, itype, ndim) + axes = [axis for axis in c.op.axis] + fused = s[c].fuse(*axes) + s[c].parallel(fused) + return s, [a, b, c] + + +def _binary_logic_gpu(compute_func, op, itype, ndim): + s, a, b, c = compute_func(op, itype, ndim) + axes = [axis for axis in c.op.axis] + fused = s[c].fuse(*axes) + bx, tx = s[c].split(fused, factor=64) + s[c].bind(bx, tvm.thread_axis('blockIdx.x')) + s[c].bind(tx, tvm.thread_axis('threadIdx.x')) + return s, [a, b, c] + + +# register binary element-wise logic ops with broadcasting supported +for op_name in _bin_logic_op_map.keys(): + defop(name='{}_cpu'.format(op_name), op=op_name, **_bin_logic_cpu_attrs)(_binary_logic_cpu) + defop(name='{}_gpu'.format(op_name), op=op_name, **_bin_logic_gpu_attrs)(_binary_logic_gpu) + + +# Note that `b.dtype` is hard-coded as 'float64'. +# We should always promote `a`'s elements to `b.dtype`. 
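+# For instance (illustrative only): with itype='int32' and the scalar b=2.5, the
+# generated kernel evaluates a[idx].astype('float64') == 2.5, so an element 2
+# compares as 2.0 == 2.5 and yields False, rather than being compared after
+# truncation of the scalar.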
+_bin_scalar_logic_op_map = { + 'equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) == b, + 'not_equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) != b, + 'greater_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) > b, + 'less_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) < b, + 'greater_equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) >= b, + 'less_equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) <= b, +} + + +def _compute_binary_scalar_logic(op, dtype, ndim): + a = tvm.placeholder([tvm.var() for _ in range(ndim)], name='a', dtype=dtype) + b = tvm.var('b', dtype='float64') + c = tvm.compute([tvm.var() for _ in range(ndim)], + lambda *idx: _bin_scalar_logic_op_map[op](a, b, *idx), name='c') + s = tvm.create_schedule(c.op) + return s, a, b, c + + +_bin_scalar_logic_cpu_attrs = { + 'compute_func': _compute_binary_scalar_logic, + 'target': 'cpu', + 'itype': AllTypes + ['bool'], + 'ndim': list(range(6)) +} + +_bin_scalar_logic_gpu_attrs = { + 'compute_func': _compute_binary_scalar_logic, + 'target': 'gpu', + 'itype': AllTypes + ['bool'], + 'ndim': list(range(6)) +} + + +# register binary element-wise scalar logic ops +for op_name in _bin_scalar_logic_op_map.keys(): + defop(name='{}_cpu'.format(op_name), op=op_name, + **_bin_scalar_logic_cpu_attrs)(_binary_logic_cpu) + defop(name='{}_gpu'.format(op_name), op=op_name, + **_bin_scalar_logic_gpu_attrs)(_binary_logic_gpu) diff --git a/contrib/tvmop/opdef.py b/contrib/tvmop/opdef.py index 32d1832d13dd..39c42f4dd465 100644 --- a/contrib/tvmop/opdef.py +++ b/contrib/tvmop/opdef.py @@ -74,11 +74,12 @@ def __call__(self, *args, **kwargs): def invoke_all(self): for each_kwargs in self.arg_combination: - if (self.attrs_valid(**each_kwargs)): + if self.attrs_valid(**each_kwargs): sch, args = self.func(**each_kwargs) name = self.name \ + ''.join(["{}_{}".format(key, each_kwargs[key]) for key in self.attrs]) \ - + ''.join(["%s_%d" % (arg.dtype, len(arg.shape)) for arg in args]) + + ''.join(["%s_%d" % (arg.dtype, len(arg.shape)) + for arg in args if hasattr(arg, 'shape')]) yield sch, args, name def get_binds(self, args): @@ -107,6 +108,7 @@ def defop(name, target=None, auto_broadcast=False, **kwargs): """ assert name is not None and len(name) > 0 target = "cpu" if target is None else target + def _defop(func): opdef = OpDef(func, name, target, auto_broadcast, **kwargs) __OP_DEF__.append(opdef) diff --git a/contrib/tvmop/utils.py b/contrib/tvmop/utils.py index 329dce2148d9..39d7a8092005 100644 --- a/contrib/tvmop/utils.py +++ b/contrib/tvmop/utils.py @@ -21,16 +21,18 @@ AllTypes = ["float32", "float64", "float16", "uint8", "int8", "int32", "int64"] RealTypes = ["float32", "float64", "float16"] -def assign_by_req(a, req): + +def assign_by_req(a, req, otype=None): b = tvm.placeholder(a.shape, name='assign_by_req_b', dtype=a.dtype) - if (req == "kAddTo"): - c = tvm.compute(a.shape, lambda *idx: a[idx] + b[idx]) + if req == "kAddTo": + c = tvm.compute(a.shape, lambda *idx: a[idx].astype(otype) + b[idx] + if otype else a[idx] + b[idx]) else: - c = tvm.compute(a.shape, lambda *idx: a[idx]) + c = tvm.compute(a.shape, lambda *idx: a[idx].astype(otype) if otype else a[idx]) return b, c -def reduce_axes(X, axes, reducer): +def reduce_axes(X, axes, reducer, atype=None): def get_index(idx, ridx): j = 0 k = 0 @@ -45,5 +47,7 @@ def get_index(idx, ridx): odim = (len(ishape) + 1 - axes[0]) // 2 oshape = [tvm.var() for _ in range(odim)] ridx = [tvm.reduce_axis((0, ishape[i])) for (i, val) in enumerate(axes) if val == 1] - ret = 
tvm.compute(oshape, lambda *idx: reducer(X[get_index(idx, ridx)], axis=ridx), name='ret') + ret = tvm.compute(oshape, lambda *idx: reducer(X[get_index(idx, ridx)].astype(atype) + if atype else X[get_index(idx, ridx)], + axis=ridx), name='ret') return ret diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h index 70af75424252..f64703092121 100755 --- a/include/mxnet/tensor_blob.h +++ b/include/mxnet/tensor_blob.h @@ -382,6 +382,7 @@ class TBlob { case mshadow::kInt32: return DLDataType{kDLInt, 32, 1}; case mshadow::kInt8: return DLDataType{kDLInt, 8, 1}; case mshadow::kInt64: return DLDataType{kDLInt, 64, 1}; + case mshadow::kBool: return DLDataType{kDLUInt, 1, 1}; default: { LOG(FATAL) << "Unknown type_flag=" << type_flag; return DLDataType(); diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 0b72865dc17a..4e3c7efa7be3 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -69,6 +69,7 @@ np.int32: 4, np.int8: 5, np.int64: 6, + np.bool_: 7, } _DTYPE_MX_TO_NP = { @@ -80,6 +81,7 @@ 4: np.int32, 5: np.int8, 6: np.int64, + 7: np.bool_, } _STORAGE_TYPE_STR_TO_ID = { @@ -2995,6 +2997,10 @@ def get_indexing_dispatch_code(key): for idx in key: if isinstance(idx, (NDArray, np.ndarray, list, tuple)): + if getattr(idx, 'dtype', None) == np.bool_: + raise TypeError('ndarray indexing does not support boolean ndarray' + ' in a tuple of indices. Only single boolean ndarray' + ' as an index is supported.') return _NDARRAY_ADVANCED_INDEXING elif sys.version_info[0] > 2 and isinstance(idx, range): return _NDARRAY_ADVANCED_INDEXING diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index e8332f1a83ef..13e2c39f3670 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -35,7 +35,8 @@ 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', - 'unique', 'lcm', 'tril', 'identity', 'take'] + 'unique', 'lcm', 'tril', 'identity', 'take', 'equal', 'not_equal', 'greater', 'less', + 'greater_equal', 'less_equal'] @set_module('mxnet.ndarray.numpy') @@ -3464,3 +3465,198 @@ def hypot(x1, x2, out=None): [ 5., 5., 5.]]) """ return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out) + + +@set_module('mxnet.ndarray.numpy') +def equal(x1, x2, out=None): + """ + Return (x1 == x2) element-wise. + Parameters + ---------- + x1, x2 : ndarrays or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : ndarray or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. 
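+ When one operand is a scalar, the comparison dispatches to ``_npi.equal_scalar``,
+ whose TVM kernel (registered in ``contrib/tvmop/core/umath.py`` in this diff)
+ first promotes the array's elements to the scalar's float64 dtype; e.g.
+ (illustrative) ``np.equal(np.array([2], dtype=np.int32), 2.5)`` gives ``array([False])``.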
+ See Also + -------- + not_equal, greater_equal, less_equal, greater, less + Examples + -------- + >>> np.equal(np.ones((2, 1)), np.zeros((1, 3))) + array([[False, False, False], + [False, False, False]]) + >>> np.equal(1, np.ones(1)) + array([ True]) + """ + return _ufunc_helper(x1, x2, _npi.equal, _np.equal, _npi.equal_scalar, None, out) + + +@set_module('mxnet.ndarray.numpy') +def not_equal(x1, x2, out=None): + """ + Return (x1 != x2) element-wise. + Parameters + ---------- + x1, x2 : ndarrays or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : ndarray or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, greater, greater_equal, less, less_equal + Examples + -------- + >>> np.not_equal(np.ones((2, 1)), np.zeros((1, 3))) + array([[ True, True, True], + [ True, True, True]]) + >>> np.not_equal(1, np.ones(1)) + array([False]) + """ + return _ufunc_helper(x1, x2, _npi.not_equal, _np.not_equal, _npi.not_equal_scalar, None, out) + + +@set_module('mxnet.ndarray.numpy') +def greater(x1, x2, out=None): + """ + Return the truth value of (x1 > x2) element-wise. + Parameters + ---------- + x1, x2 : ndarrays or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : ndarray or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, greater, greater_equal, less, less_equal + Examples + -------- + >>> np.greater(np.ones((2, 1)), np.zeros((1, 3))) + array([[ True, True, True], + [ True, True, True]]) + >>> np.greater(1, np.ones(1)) + array([False]) + """ + return _ufunc_helper(x1, x2, _npi.greater, _np.greater, _npi.greater_scalar, + _npi.less_scalar, out) + + +@set_module('mxnet.ndarray.numpy') +def less(x1, x2, out=None): + """ + Return the truth value of (x1 < x2) element-wise. + Parameters + ---------- + x1, x2 : ndarrays or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : ndarray or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars.
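+ Both scalar kernels passed to ``_ufunc_helper`` below matter: ``_npi.less_scalar``
+ implements ``ndarray < scalar``, while the reversed ``_npi.greater_scalar`` handles
+ ``scalar < ndarray``; e.g. (illustrative) ``np.less(2, np.array([1., 3.]))`` gives
+ ``array([False,  True])``.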
+ See Also + -------- + equal, greater, greater_equal, less, less_equal + Examples + -------- + >>> np.less(np.ones((2, 1)), np.zeros((1, 3))) + array([[False, False, False], + [False, False, False]]) + >>> np.less(1, np.ones(1)) + array([False]) + """ + return _ufunc_helper(x1, x2, _npi.less, _np.less, _npi.less_scalar, _npi.greater_scalar, out) + + +@set_module('mxnet.ndarray.numpy') +def greater_equal(x1, x2, out=None): + """ + Return the truth value of (x1 >= x2) element-wise. + Parameters + ---------- + x1, x2 : ndarrays or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : ndarray or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, greater, greater_equal, less, less_equal + Examples + -------- + >>> np.greater_equal(np.ones((2, 1)), np.zeros((1, 3))) + array([[ True, True, True], + [ True, True, True]]) + >>> np.greater_equal(1, np.ones(1)) + array([ True]) + """ + return _ufunc_helper(x1, x2, _npi.greater_equal, _np.greater_equal, _npi.greater_equal_scalar, + _npi.less_equal_scalar, out) + + +@set_module('mxnet.ndarray.numpy') +def less_equal(x1, x2, out=None): + """ + Return the truth value of (x1 <= x2) element-wise. + Parameters + ---------- + x1, x2 : ndarrays or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : ndarray or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, greater, greater_equal, less, less_equal + Examples + -------- + >>> np.less_equal(np.ones((2, 1)), np.zeros((1, 3))) + array([[False, False, False], + [False, False, False]]) + >>> np.less_equal(1, np.ones(1)) + array([ True]) + """ + return _ufunc_helper(x1, x2, _npi.less_equal, _np.less_equal, _npi.less_equal_scalar, + _npi.greater_equal_scalar, out) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 7ba0f0d7d813..ab20e04785d6 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -29,7 +29,6 @@ from builtins import slice as py_slice from array import array as native_array -import sys import ctypes import warnings import numpy as _np @@ -38,7 +37,7 @@ get_oshape_of_gather_nd_op from ..ndarray._internal import _set_np_ndarray_class from .
import _op as _mx_np_op -from ..base import check_call, _LIB, NDArrayHandle +from ..base import check_call, _LIB, NDArrayHandle, c_array from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types, integer_types from ..context import Context from ..util import _sanity_check_params, set_module @@ -54,7 +53,8 @@ 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', - 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take'] + 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', + 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -86,6 +86,35 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t): return hdl +def _reshape_view(a, *shape): + """Returns a **view** of this array with a new shape without altering any data. + + Parameters + ---------- + shape : tuple of int, or n ints + The new shape should not change the array size, namely + ``np.prod(new_shape)`` should be equal to ``np.prod(a.shape)``. + Some dimensions of the shape can take special value -1, which + infers the dimension of the output shape by using the remainder of the + input dimensions keeping the size of the new array same as that of the input array. + At most one dimension of shape can be -1. + + Returns + ------- + ndarray + An array with desired shape that shares data with this array. + """ + if len(shape) == 1 and isinstance(shape[0], (list, tuple)): + shape = shape[0] + handle = NDArrayHandle() + check_call(_LIB.MXNDArrayReshape64(a.handle, + len(shape), + c_array(ctypes.c_int64, shape), + False, + ctypes.byref(handle))) + return ndarray(handle=handle, writable=a.writable) + + # Have to use 0 as default value for stype since pylint does not allow # importing _STORAGE_TYPE_DEFAULT from ndarray.py. def _np_ndarray_cls(handle, writable=True, stype=0): @@ -97,18 +126,6 @@ def _np_ndarray_cls(handle, writable=True, stype=0): _set_np_ndarray_class(_np_ndarray_cls) - -def _get_index(idx): - if isinstance(idx, NDArray) and not isinstance(idx, ndarray): - raise TypeError('Cannot have mx.nd.NDArray as index') - if isinstance(idx, ndarray): - return idx.as_nd_ndarray() - elif sys.version_info[0] > 2 and isinstance(idx, range): - return array(_np.arange(idx.start, idx.stop, idx.step, dtype=_np.int32)).as_nd_ndarray() - else: - return idx - - _NUMPY_ARRAY_FUNCTION_DICT = {} _NUMPY_ARRAY_UFUNC_DICT = {} @@ -272,8 +289,35 @@ def __getitem__(self, key): Overriding the method in NDArray class in a numpy fashion. Calling numpy ndarray's _get_np_basic_indexing(key) and _get_np_advanced_indexing(key). 
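For a boolean ndarray key, the key is flattened together with the matching leading axes of the data and dispatched to `_npi.boolean_mask`; an illustrative round trip:

>>> a = np.array([[1., 2.], [3., 4.]])
>>> a[np.array([True, False], dtype=np.bool_)]
array([[1., 2.]])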
""" + # handling possible boolean indexing first ndim = self.ndim shape = self.shape + + if isinstance(key, list): + try: + new_key = _np.array(key) + if new_key.dtype == _np.bool_: + key = new_key + except Exception as err: + raise TypeError('{}'.format(str(err))) + if isinstance(key, _np.ndarray) and key.dtype == _np.bool_: + key = array(key, dtype='bool') + if isinstance(key, ndarray) and key.dtype == _np.bool_: # boolean indexing + key_shape = key.shape + key_ndim = len(key_shape) + if ndim < key_ndim: + raise IndexError('too many indices, whose ndim = {}, for array with ndim = {}' + .format(key_ndim, ndim)) + for i in range(key_ndim): + if key_shape[i] != shape[i]: + raise IndexError('boolean index did not match indexed array along dimension {};' + 'dimension is {} but corresponding boolean dimension is {}' + .format(i, shape[i], key_shape[i])) + remaining_dims = shape[key_ndim:] + data = _reshape_view(self, -1, *remaining_dims) + key = _reshape_view(key, -1) + return _reshape_view(_npi.boolean_mask(data, key), -1, *remaining_dims) + if ndim == 0: if key != (): raise IndexError('scalar tensor can only accept `()` as index') @@ -502,66 +546,30 @@ def __rpow__(self, other): def __eq__(self, other): """x.__eq__(y) <=> x == y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, ndarray): - return _npi.equal(self, other) - elif isinstance(other, numeric_types): - return _npi.equal_scalar(self, float(other)) - else: - raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) + return equal(self, other) def __hash__(self): raise NotImplementedError def __ne__(self, other): """x.__ne__(y) <=> x != y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, ndarray): - return _npi.not_equal(self, other) - elif isinstance(other, numeric_types): - return _npi.not_equal_scalar(self, float(other)) - else: - raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) + return not_equal(self, other) def __gt__(self, other): """x.__gt__(y) <=> x > y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, ndarray): - return _npi.greater(self, other) - elif isinstance(other, numeric_types): - return _npi.greater_scalar(self, float(other)) - else: - raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) + return greater(self, other) def __ge__(self, other): """x.__ge__(y) <=> x >= y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, ndarray): - return _npi.greater_equal(self, other) - elif isinstance(other, numeric_types): - return _npi.greater_equal_scalar(self, float(other)) - else: - raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) + return greater_equal(self, other) def __lt__(self, other): """x.__lt__(y) <=> x < y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, ndarray): - return _npi.less(self, other) - elif isinstance(other, numeric_types): - return _npi.less_scalar(self, float(other)) - else: - raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) + return less(self, other) def __le__(self, other): """x.__le__(y) <=> x <= y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, ndarray): - return _npi.less_equal(self, other) - elif isinstance(other, numeric_types): - return 
_npi.less_equal_scalar(self, float(other)) - else: - raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) + return less_equal(self, other) def __bool__(self): num_elements = self.size @@ -694,7 +702,7 @@ def __repr__(self): if 'dtype=' in array_str: if dtype == _np.float32: array_str = array_str[:array_str.rindex(',')] + ')' - elif dtype != _np.float32: + elif dtype not in (_np.float32, _np.bool_): array_str = array_str[:-1] + ', dtype={})'.format(dtype.__name__) context = self.context @@ -4980,3 +4988,195 @@ def hypot(x1, x2, out=None): [ 5., 5., 5.]]) """ return _mx_nd_np.hypot(x1, x2, out=out) + + +@set_module('mxnet.numpy') +def equal(x1, x2, out=None): + """ + Return (x1 == x2) element-wise. + Parameters + ---------- + x1, x2 : ndarrays or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : ndarray or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + not_equal, greater_equal, less_equal, greater, less + Examples + -------- + >>> np.equal(np.ones((2, 1)), np.zeros((1, 3))) + array([[False, False, False], + [False, False, False]]) + >>> np.equal(1, np.ones(1)) + array([ True]) + """ + return _mx_nd_np.equal(x1, x2, out) + + +@set_module('mxnet.numpy') +def not_equal(x1, x2, out=None): + """ + Return (x1 != x2) element-wise. + Parameters + ---------- + x1, x2 : ndarrays or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : ndarray or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, greater, greater_equal, less, less_equal + Examples + -------- + >>> np.not_equal(np.ones((2, 1)), np.zeros((1, 3))) + array([[ True, True, True], + [ True, True, True]]) + >>> np.not_equal(1, np.ones(1)) + array([False]) + """ + return _mx_nd_np.not_equal(x1, x2, out) + + +@set_module('mxnet.numpy') +def greater(x1, x2, out=None): + """ + Return the truth value of (x1 > x2) element-wise. + Parameters + ---------- + x1, x2 : ndarrays or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : ndarray or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars.
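+ Like the other comparison front-ends in this module, this simply delegates to
+ :func:`mxnet.ndarray.numpy.greater` via ``_mx_nd_np.greater``, as the ``return``
+ statement below shows.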
+ See Also + -------- + equal, greater, greater_equal, less, less_equal + Examples + -------- + >>> np.greater(np.ones((2, 1)), np.zeros((1, 3))) + array([[ True, True, True], + [ True, True, True]]) + >>> np.greater(1, np.ones(1)) + array([False]) + """ + return _mx_nd_np.greater(x1, x2, out) + + +@set_module('mxnet.numpy') +def less(x1, x2, out=None): + """ + Return the truth value of (x1 < x2) element-wise. + Parameters + ---------- + x1, x2 : ndarrays or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : ndarray or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, greater, greater_equal, less, less_equal + Examples + -------- + >>> np.less(np.ones((2, 1)), np.zeros((1, 3))) + array([[False, False, False], + [False, False, False]]) + >>> np.less(1, np.ones(1)) + array([False]) + """ + return _mx_nd_np.less(x1, x2, out) + + +@set_module('mxnet.numpy') +def greater_equal(x1, x2, out=None): + """ + Return the truth value of (x1 >= x2) element-wise. + Parameters + ---------- + x1, x2 : ndarrays or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : ndarray or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, greater, greater_equal, less, less_equal + Examples + -------- + >>> np.greater_equal(np.ones((2, 1)), np.zeros((1, 3))) + array([[ True, True, True], + [ True, True, True]]) + >>> np.greater_equal(1, np.ones(1)) + array([ True]) + """ + return _mx_nd_np.greater_equal(x1, x2, out) + + +@set_module('mxnet.numpy') +def less_equal(x1, x2, out=None): + """ + Return the truth value of (x1 <= x2) element-wise. + Parameters + ---------- + x1, x2 : ndarrays or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : ndarray or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars.
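+ The boolean result can feed straight into reductions; with this diff, ``np.sum``
+ of a bool array defaults to an int64 output (see the ``NumpySumType`` change
+ below), so e.g. ``np.sum(np.less_equal(np.array([1., 2., 3.]), 2.))`` evaluates
+ to ``2`` (illustrative).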
+ See Also + -------- + equal, greater, greater_equal, less, less_equal + Examples + -------- + >>> np.less_equal(np.ones((2, 1)), np.zeros((1, 3))) + array([[False, False, False], + [False, False, False]]) + >>> np.less_equal(1, np.ones(1)) + array([ True]) + """ + return _mx_nd_np.less_equal(x1, x2, out) diff --git a/python/mxnet/numpy/utils.py b/python/mxnet/numpy/utils.py index 920897efc80b..b2d0dd96d324 100644 --- a/python/mxnet/numpy/utils.py +++ b/python/mxnet/numpy/utils.py @@ -22,7 +22,8 @@ import numpy as onp -__all__ = ['float16', 'float32', 'float64', 'uint8', 'int32', 'int8', 'int64', 'pi'] +__all__ = ['float16', 'float32', 'float64', 'uint8', 'int32', 'int8', 'int64', + 'bool', 'bool_', 'pi', 'inf', 'nan'] float16 = onp.float16 float32 = onp.float32 @@ -31,5 +32,9 @@ int32 = onp.int32 int8 = onp.int8 int64 = onp.int64 +bool_ = onp.bool_ +bool = onp.bool pi = onp.pi +inf = onp.inf +nan = onp.nan diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 9c055c401b31..4a42d3af8de4 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -37,7 +37,8 @@ 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', - 'unique', 'lcm', 'tril', 'identity', 'take'] + 'unique', 'lcm', 'tril', 'identity', 'take', + 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal'] def _num_outputs(sym): @@ -138,63 +139,27 @@ def __deepcopy__(self, _): def __eq__(self, other): """x.__eq__(y) <=> x == y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, _Symbol): - return _npi.equal(self, other) - elif isinstance(other, numeric_types): - return _npi.equal_scalar(self, float(other)) - else: - raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) + return equal(self, other) def __ne__(self, other): """x.__ne__(y) <=> x != y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, _Symbol): - return _npi.not_equal(self, other) - elif isinstance(other, numeric_types): - return _npi.not_equal_scalar(self, float(other)) - else: - raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) + return not_equal(self, other) def __gt__(self, other): """x.__gt__(y) <=> x > y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, _Symbol): - return _npi.greater(self, other) - elif isinstance(other, numeric_types): - return _npi.greater_scalar(self, float(other)) - else: - raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) + return greater(self, other) def __ge__(self, other): """x.__ge__(y) <=> x >= y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, _Symbol): - return _npi.greater_equal(self, other) - elif isinstance(other, numeric_types): - return _npi.greater_equal_scalar(self, float(other)) - else: - raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) + return greater_equal(self, other) def __lt__(self, other): """x.__lt__(y) <=> x < y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, _Symbol): - return _npi.less(self, other) - elif isinstance(other, numeric_types): - return _npi.less_scalar(self, float(other)) - else: - raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) + return less(self, other) def __le__(self, other): """x.__le__(y) <=> x <= y""" - # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported - if isinstance(other, _Symbol): - return _npi.less_equal(self, other) - elif isinstance(other, numeric_types): - return _npi.less_equal_scalar(self, float(other)) - else: - raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) + return less_equal(self, other) def __len__(self): raise NotImplementedError @@ -3552,4 +3517,199 @@ def unique(ar, return_index=False, return_inverse=False, return_counts=False, ax return _npi.unique(ar, return_index, return_inverse, return_counts, axis) +@set_module('mxnet.symbol.numpy') +def equal(x1, x2, out=None): + """ + Return (x1 == x2) element-wise. + Parameters + ---------- + x1, x2 : _Symbol or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : Dummy parameter, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : _Symbol or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + not_equal, greater_equal, less_equal, greater, less + Examples + -------- + >>> np.equal(np.ones((2, 1)), np.zeros((1, 3))) + array([[False, False, False], + [False, False, False]]) + >>> np.equal(1, np.ones(1)) + array([ True]) + """ + return _ufunc_helper(x1, x2, _npi.equal, _np.equal, _npi.equal_scalar, None, out) + + +@set_module('mxnet.symbol.numpy') +def not_equal(x1, x2, out=None): + """ + Return (x1 != x2) element-wise. + Parameters + ---------- + x1, x2 : _Symbol or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : Dummy parameter, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : _Symbol or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, greater, greater_equal, less, less_equal + Examples + -------- + >>> np.not_equal(np.ones((2, 1)), np.zeros((1, 3))) + array([[ True, True, True], + [ True, True, True]]) + >>> np.not_equal(1, np.ones(1)) + array([False]) + """ + return _ufunc_helper(x1, x2, _npi.not_equal, _np.not_equal, _npi.not_equal_scalar, None, out) + + +@set_module('mxnet.symbol.numpy') +def greater(x1, x2, out=None): + """ + Return the truth value of (x1 > x2) element-wise. + Parameters + ---------- + x1, x2 : _Symbol or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : Dummy parameter, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : _Symbol or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, greater, greater_equal, less, less_equal + Examples + -------- + >>> np.greater(np.ones((2, 1)), np.zeros((1, 3))) + array([[ True, True, True], + [ True, True, True]]) + >>> np.greater(1, np.ones(1)) + array([False]) + """ + return _ufunc_helper(x1, x2, _npi.greater, _np.greater, _npi.greater_scalar, + _npi.less_scalar, out) + + +@set_module('mxnet.symbol.numpy') +def less(x1, x2, out=None): + """ + Return the truth value of (x1 < x2) element-wise. + Parameters + ---------- + x1, x2 : _Symbol or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : Dummy parameter, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : _Symbol or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, greater, greater_equal, less, less_equal + Examples + -------- + >>> np.less(np.ones((2, 1)), np.zeros((1, 3))) + array([[False, False, False], + [False, False, False]]) + >>> np.less(1, np.ones(1)) + array([False]) + """ + return _ufunc_helper(x1, x2, _npi.less, _np.less, _npi.less_scalar, _npi.greater_scalar, out) + + +@set_module('mxnet.symbol.numpy') +def greater_equal(x1, x2, out=None): + """ + Return the truth value of (x1 >= x2) element-wise. + Parameters + ---------- + x1, x2 : _Symbol or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : Dummy parameter, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : _Symbol or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars. + See Also + -------- + equal, greater, greater_equal, less, less_equal + Examples + -------- + >>> np.greater_equal(np.ones((2, 1)), np.zeros((1, 3))) + array([[ True, True, True], + [ True, True, True]]) + >>> np.greater_equal(1, np.ones(1)) + array([ True]) + """ + return _ufunc_helper(x1, x2, _npi.greater_equal, _np.greater_equal, _npi.greater_equal_scalar, + _npi.less_equal_scalar, out) + + +@set_module('mxnet.symbol.numpy') +def less_equal(x1, x2, out=None): + """ + Return the truth value of (x1 <= x2) element-wise. + Parameters + ---------- + x1, x2 : _Symbol or scalars + Input arrays. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + out : Dummy parameter, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + Returns + ------- + out : _Symbol or scalar + Output array of type bool, element-wise comparison of `x1` and `x2`. + This is a scalar if both `x1` and `x2` are scalars.
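+ As with the other ``_Symbol`` comparisons above, no values are computed here:
+ the call only inserts an ``_npi.less_equal`` (or scalar-variant) node into the
+ graph, and the boolean output materializes when the symbol is bound and executed.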
+ See Also + -------- + equal, greater, greater_equal, less, less_equal + Examples + -------- + >>> np.less_equal(np.ones((2, 1)), np.zeros((1, 3))) + array([[False, False, False], + [False, False, False]]) + >>> np.less_equal(1, np.ones(1)) + array([ True]) + """ + return _ufunc_helper(x1, x2, _npi.less_equal, _np.less_equal, _npi.less_equal_scalar, + _npi.greater_equal_scalar, out) + + _set_np_symbol_class(_Symbol) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index abdf57039c3a..53b5e456cc70 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -50,6 +50,7 @@ from .symbol import Symbol from .symbol.numpy import _Symbol as np_symbol from .util import use_np # pylint: disable=unused-import +from .runtime import Features def default_context(): @@ -2225,3 +2226,10 @@ def collapse_sum_like(a, shape): def is_cd_run(): """Checks if the test is running as part of a Continuous Delivery run""" return os.environ.get("CD_JOB", 0) == "1" + + +_features = Features() + + +def has_tvm_ops(): + return _features.is_enabled("TVM_OP") diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index cc21dd242a2d..d1c3132b79a7 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -281,7 +281,7 @@ NDArray NDArray::Slice(index_t begin, index_t end) const { CHECK_EQ(storage_type(), kDefaultStorage); NDArray ret = this->Detach(); size_t length = shape_.ProdShape(1, shape_.ndim()); - MSHADOW_TYPE_SWITCH(ret.dtype(), DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(ret.dtype(), DType, { ret.byte_offset_ += begin * length * sizeof(DType); }); ret.reuse_ = false; diff --git a/src/ndarray/ndarray_function.cc b/src/ndarray/ndarray_function.cc index 1a699b12d76d..653dec84563a 100644 --- a/src/ndarray/ndarray_function.cc +++ b/src/ndarray/ndarray_function.cc @@ -36,6 +36,15 @@ template<> void Copy<cpu, cpu>(const TBlob &from, TBlob *to, Context from_ctx, Context to_ctx, RunContext ctx) { + if (from.type_flag_ == mshadow::kBool || to->type_flag_ == mshadow::kBool) { + CHECK_EQ(from.type_flag_, to->type_flag_) << "Only supports copying data between" " two boolean tensors."; + const index_t size = from.Size(); + CHECK_EQ(size, to->Size()) << "copying size mismatch, from: " << size * sizeof(bool) + << " bytes, to: " << to->Size() * sizeof(bool) << " bytes."; + common::ParallelCopy(to->dptr<bool>(), from.dptr<bool>(), size); + return; + } MSHADOW_TYPE_SWITCH(to->type_flag_, DType, { if (to->type_flag_ == from.type_flag_) { const index_t size = static_cast<index_t>(from.Size()); diff --git a/src/ndarray/ndarray_function.cu b/src/ndarray/ndarray_function.cu index 0ab40a794198..6439c417bfe3 100644 --- a/src/ndarray/ndarray_function.cu +++ b/src/ndarray/ndarray_function.cu @@ -44,7 +44,7 @@ void Copy<cpu, gpu>(const TBlob &from, TBlob *to, RunContext ctx) { CHECK_EQ(to->type_flag_, from.type_flag_) << "Source and target must have the same data type when copying across devices."; - MSHADOW_TYPE_SWITCH(to->type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(to->type_flag_, DType, { mshadow::Copy(to->FlatTo1D<gpu, DType>(), from.FlatTo1D<cpu, DType>(), ctx.get_stream<gpu>()); @@ -57,7 +57,7 @@ void Copy<gpu, cpu>(const TBlob &from, TBlob *to, RunContext ctx) { CHECK_EQ(to->type_flag_, from.type_flag_) << "Source and target must have the same data type when copying across devices."; - MSHADOW_TYPE_SWITCH(to->type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(to->type_flag_, DType, { mshadow::Copy(to->FlatTo1D<cpu, DType>(), from.FlatTo1D<gpu, DType>(), ctx.get_stream<gpu>()); @@ -70,12 +70,16 @@ void Copy<gpu, gpu>(const TBlob &from, TBlob *to, RunContext ctx) { if (from_ctx.dev_id == to_ctx.dev_id) { mshadow::Stream<gpu>* s = ctx.get_stream<gpu>(); - MSHADOW_TYPE_SWITCH(to->type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(to->type_flag_, DType, { if (to->type_flag_ == from.type_flag_) { mshadow::Copy(to->FlatTo1D<gpu, DType>(s), from.FlatTo1D<gpu, DType>(s), s); } else { + CHECK_NE(from.type_flag_, mshadow::kBool) + << "Copying boolean ndarray across devices is not supported"; + CHECK_NE(to->type_flag_, mshadow::kBool) + << "Copying boolean ndarray across devices is not supported"; MSHADOW_TYPE_SWITCH(from.type_flag_, SrcDType, { to->FlatTo1D<gpu, DType>(s) = mshadow::expr::tcast<DType>(from.FlatTo1D<gpu, SrcDType>(s)); diff --git a/src/operator/contrib/boolean_mask.cc b/src/operator/contrib/boolean_mask.cc index 0afc57f6ab54..a54cc917776d 100644 --- a/src/operator/contrib/boolean_mask.cc +++ b/src/operator/contrib/boolean_mask.cc @@ -35,7 +35,7 @@ bool BooleanMaskType(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1); TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); - return out_attrs->at(0) != -1; + return in_attrs->at(0) != -1 && in_attrs->at(1) != -1 && out_attrs->at(0) != -1; } bool BooleanMaskStorageType(const nnvm::NodeAttrs& attrs, @@ -115,7 +115,7 @@ inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs, std::vector<int32_t> prefix_sum(idx_size, 0); size_t valid_num = 0; // Calculate prefix sum - MSHADOW_TYPE_SWITCH(idx.dtype(), DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(idx.dtype(), DType, { DType* idx_dptr = idx.data().dptr<DType>(); for (size_t i = 0; i < idx_size; i++) { prefix_sum[i] = (i == 0) ? 0 : prefix_sum[i - 1]; @@ -154,7 +154,7 @@ inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs, const NDArray& idx = inputs[2]; const NDArray& igrad_data = outputs[0]; MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, { - MSHADOW_TYPE_SWITCH(idx.dtype(), IType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(idx.dtype(), IType, { size_t input_size = igrad_data.shape().Size(); size_t idx_size = idx.shape()[0]; size_t col_size = input_size / idx_size; @@ -179,6 +179,7 @@ inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs, } NNVM_REGISTER_OP(_contrib_boolean_mask) +.add_alias("_npi_boolean_mask") .describe(R"code( Given an n-d NDArray data, and a 1-d NDArray index, the operator produces an un-predeterminable shaped n-d NDArray out, diff --git a/src/operator/contrib/boolean_mask.cu b/src/operator/contrib/boolean_mask.cu index c4a06d25d70a..71d91c63f64e 100644 --- a/src/operator/contrib/boolean_mask.cu +++ b/src/operator/contrib/boolean_mask.cu @@ -66,7 +66,7 @@ inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs, ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(temp_storage_bytes), s); prefix_sum = reinterpret_cast<int32_t*>(workspace.dptr_); d_temp_storage = workspace.dptr_ + buffer_size; - MSHADOW_TYPE_SWITCH(idx.dtype(), IType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(idx.dtype(), IType, { mxnet_op::Kernel<mshadow_op::identity_with_cast, gpu>::Launch( s, idx.shape()[0], prefix_sum, idx.data().dptr<IType>()); }); @@ -129,7 +129,7 @@ inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs, ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(temp_storage_bytes), s); prefix_sum = reinterpret_cast<int32_t*>(workspace.dptr_); d_temp_storage = workspace.dptr_ + buffer_size; - MSHADOW_TYPE_SWITCH(idx.dtype(), IType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(idx.dtype(), IType, { mxnet_op::Kernel<mshadow_op::identity_with_cast, gpu>::Launch( s, idx.shape()[0], prefix_sum, idx.data().dptr<IType>()); }); diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index d90aa268195a..b46ce8a598d9 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -491,6 +491,17 @@ struct AccType<mshadow::half::half_t> { .add_enum("int64", mshadow::kInt64) +#define MXNET_ADD_ALL_TYPES_WITH_BOOL \ + .add_enum("float32", mshadow::kFloat32) \ + .add_enum("float64", mshadow::kFloat64) \ + .add_enum("float16", mshadow::kFloat16) \ + .add_enum("uint8", mshadow::kUint8) \ + .add_enum("int8", mshadow::kInt8) \ + .add_enum("int32", mshadow::kInt32) \ + .add_enum("int64", mshadow::kInt64) \ + .add_enum("bool", mshadow::kBool) + + /* \brief Compute flattened index given coordinates and shape. */ template<int ndim> MSHADOW_XINLINE index_t ravel(const Shape<ndim>& coord, const Shape<ndim>& shape) { @@ -597,6 +608,11 @@ template<typename xpu> MSHADOW_CINLINE void copy(mshadow::Stream<xpu> *s, const TBlob& to, const TBlob& from) { CHECK_EQ(from.Size(), to.Size()); CHECK_EQ(from.dev_mask(), to.dev_mask()); + if (from.type_flag_ == mshadow::kBool || to.type_flag_ == mshadow::kBool) { + CHECK_EQ(from.type_flag_, to.type_flag_) << "Only supports copying between boolean ndarrays."; + mshadow::Copy(to.FlatTo1D<xpu, bool>(s), from.FlatTo1D<xpu, bool>(s), s); + return; + } MSHADOW_TYPE_SWITCH(to.type_flag_, DType, { if (to.type_flag_ == from.type_flag_) { mshadow::Copy(to.FlatTo1D<xpu, DType>(s), from.FlatTo1D<xpu, DType>(s), s); diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h index b1348de72de5..816091a48b48 100644 --- a/src/operator/numpy/np_broadcast_reduce_op.h +++ b/src/operator/numpy/np_broadcast_reduce_op.h @@ -27,6 +27,7 @@ #include <algorithm> #include <vector> +#include <string> #include "../nn/moments-inl.h" #include "../tensor/broadcast_reduce_op.h" @@ -86,7 +87,6 @@ struct NumpyReduceAxesNoDTypeParam : public dmlc::Parameter<NumpyReduceAxesNoDTypeParam> inline TShape NumpyReduceAxesShapeImpl(const TShape& ishape, const dmlc::optional<mxnet::Tuple<int>>& axis, bool keepdims) { - // TODO(junwu): improve the logic // If input is a scalar, output should be a scalar too if (ishape.ndim() == 0) { if (axis.has_value()) { @@ -210,6 +210,10 @@ inline bool NeedSafeAcc(int itype, int otype) { return safe_acc_hint && rule; } +void TVMOpReduce(const OpContext& ctx, const TBlob& input, + const dmlc::optional<mxnet::Tuple<int>>& axis, + const TBlob& output, const OpReqType req, const std::string& reducer_name); + template<typename xpu, typename reducer, bool safe_acc_hint = false, bool normalize = false, typename OP = op::mshadow_op::identity> void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs, @@ -217,6 +221,7 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs, const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req, const std::vector<TBlob>& outputs) { + if (req[0] == kNullOp) return; const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed); if (param.initial.has_value()) { LOG(FATAL) << "initial is not supported yet"; } @@ -230,6 +235,18 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs, }); return; } + CHECK_NE(req[0], kWriteInplace) << "Reduce does not support write in-place"; + // If boolean ndarray, use the kernel generated by TVM + if (inputs[0].type_flag_ == mshadow::kBool) { + std::string reducer_name; + if (std::is_same<reducer, mshadow_op::sum>::value) { + reducer_name = "sum"; + } else { + LOG(FATAL) << "Only reduce op: `sum` is supported for boolean ndarrays"; + } + TVMOpReduce(ctx, inputs[0], param.axis, outputs[0], req[0], reducer_name); + return; + } if (param.axis.has_value() && param.axis.value().ndim() == 0) { UnaryOp::IdentityCompute<xpu>(attrs, ctx, inputs, req, outputs); } @@ -279,6 +296,8 @@ inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs, const std::vector<TBlob>& outputs) { using namespace mshadow; using namespace mshadow::expr; + CHECK_NE(outputs[0].type_flag_, kBool) << "reduce operators do not support gradient calculation " "for input tensors of boolean type."; const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed); TShape small; if (param.keepdims) { diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc index 774bc11f5de8..fdda792a9ed8 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cc +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc @@ -23,6 +23,10 @@ * \brief CPU Implementation of broadcast and reduce functions based on value. */ +#if MXNET_USE_TVM_OP +#include "../tvmop/op_module.h" +#endif // MXNET_USE_TVM_OP + #include "np_broadcast_reduce_op.h" namespace mxnet { @@ -40,7 +44,17 @@ inline bool NumpySumType(const nnvm::NodeAttrs& attrs, const NumpyReduceAxesParam &param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed); if (param.dtype.has_value()) { + if (in_attrs->at(0) == mshadow::kBool) { + CHECK(param.dtype.value() == mshadow::kInt32 + || param.dtype.value() == mshadow::kInt64 + || param.dtype.value() == mshadow::kFloat32 + || param.dtype.value() == mshadow::kFloat64) << "Only support the following output " + "dtypes when input dtype is bool: " + "int32, int64, float32, float64."; + } TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value()); + } else if (in_attrs->at(0) == mshadow::kBool) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64); } else { TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); @@ -49,6 +63,63 @@ inline bool NumpySumType(const nnvm::NodeAttrs& attrs, return out_attrs->at(0) != -1 && in_attrs->at(0) != -1; } +#if MXNET_USE_TVM_OP +static constexpr int max_reduce_ndim = 5; +TBlob PrependAxes(const TBlob& src, const int dst_ndim); +#endif // MXNET_USE_TVM_OP + +void TVMOpReduce(const OpContext& ctx, + const TBlob& input, + const dmlc::optional<mxnet::Tuple<int>>& axis, + const TBlob& output, + const OpReqType req, + const std::string& reducer_name) { +#if MXNET_USE_TVM_OP + CHECK_GE(input.ndim(), output.ndim()); + CHECK_LE(input.ndim(), max_reduce_ndim) << "TVMOpReduce only supports ndim <= " + << max_reduce_ndim; + + const TBlob expanded_output = (input.ndim() == output.ndim() ? + output : output.reshape(NumpyReduceAxesShapeImpl(input.shape_, axis, true))); + CHECK_EQ(input.ndim(), expanded_output.ndim()); + int reduce1st_dim = 0; + if (input.ndim() > 0 && input.size(0) != expanded_output.size(0)) { + reduce1st_dim = 1; + } + // collapse consecutive dimensions where reduction are performed or not performed + std::vector<index_t> ishape_vec; + for (int i = 0; i < input.ndim(); ++i) { + if (i == 0 || ((input.size(i) != expanded_output.size(i)) + != (input.size(i-1) != expanded_output.size(i-1)))) { + ishape_vec.push_back(input.size(i)); + } else { + ishape_vec.back() *= input.size(i); + } + } + // append axes after collapsed ishape to reach the max ndim allowed + for (int i = ishape_vec.size(); i < max_reduce_ndim; ++i) { + ishape_vec.push_back(1); + } + std::vector<index_t> oshape_vec; + for (size_t i = reduce1st_dim; i < ishape_vec.size(); i += 2) { + oshape_vec.push_back(ishape_vec[i]); + } + TShape ishape(ishape_vec.begin(), ishape_vec.end()), oshape(oshape_vec.begin(), oshape_vec.end()); + TBlob input_tvm = input.reshape(ishape); + TBlob output_tvm = output.reshape(oshape); + const std::string ctx_name = + (ctx.run_ctx.ctx.dev_type == mxnet::Context::DeviceType::kCPU) ? "cpu" : "gpu"; + std::ostringstream func_name; + func_name << reducer_name << "_" + << (ctx.run_ctx.ctx.dev_type == mxnet::Context::DeviceType::kCPU ? "cpu" : "gpu") + << "reduce1st_dim_" << reduce1st_dim + << "req_" << (req == kWriteTo ?
"kWriteTo" : "kAddTo"); + tvm::runtime::TVMOpModule::Get()->Call(func_name.str(), ctx, {input_tvm, output_tvm, output_tvm}); +#else + LOG(FATAL) << "Please add USE_TVM_OP=1 as a compile flag to enable TVM-generated kernels."; +#endif // MXNET_USE_TVM_OP +} + NNVM_REGISTER_OP(_np_sum) .describe(R"code()code" ADD_FILELINE) .set_num_inputs(1) diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index a786d1db5892..9553d5f69d08 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -23,6 +23,12 @@ * \brief CPU Implementation of basic functions for elementwise numpy binary broadcast operator. */ +#if MXNET_USE_TVM_OP +#include +#include +#include "../tvmop/op_module.h" +#endif // MXNET_USE_TVM_OP + #include "../tensor/elemwise_binary_broadcast_op.h" #include "../tensor/elemwise_binary_scalar_op.h" @@ -113,6 +119,22 @@ NNVM_REGISTER_OP(_npi_lcm) .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function"); +NNVM_REGISTER_OP(_npi_lcm_scalar) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser([](NodeAttrs* attrs) { + attrs->parsed = std::stod(attrs->dict["scalar"]); + }) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseIntType<1, 1>) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; + }) +.add_argument("data", "NDArray-or-Symbol", "source input") +.add_argument("scalar", "int", "scalar input") +.set_attr("FCompute", BinaryScalarOp::Compute); + MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_add_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseNone{"_copy"}); @@ -280,21 +302,222 @@ NNVM_REGISTER_OP(_backward_npi_hypot) .set_attr("FCompute", BinaryBroadcastBackwardUseIn); -NNVM_REGISTER_OP(_npi_lcm_scalar) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = std::stod(attrs->dict["scalar"]); - }) -.set_attr("FInferShape", ElemwiseShape<1, 1>) -.set_attr("FInferType", ElemwiseIntType<1, 1>) -.set_attr("FInplaceOption", - [](const NodeAttrs& attrs){ - return std::vector >{{0, 0}}; - }) -.add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "int", "scalar input") -.set_attr("FCompute", BinaryScalarOp::Compute); +static constexpr char func_equal_cpu[] = "equal_cpu"; +static constexpr char func_equal_gpu[] = "equal_gpu"; +static constexpr char func_not_equal_cpu[] = "not_equal_cpu"; +static constexpr char func_not_equal_gpu[] = "not_equal_gpu"; +static constexpr char func_greater_cpu[] = "greater_cpu"; +static constexpr char func_greater_gpu[] = "greater_gpu"; +static constexpr char func_less_cpu[] = "less_cpu"; +static constexpr char func_less_gpu[] = "less_gpu"; +static constexpr char func_greater_equal_cpu[] = "greater_equal_cpu"; +static constexpr char func_greater_equal_gpu[] = "greater_equal_gpu"; +static constexpr char func_less_equal_cpu[] = "less_equal_cpu"; +static constexpr char func_less_equal_gpu[] = "less_equal_gpu"; + +bool NumpyBinaryLogicOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + if (in_attrs->at(0) == -1 && in_attrs->at(1) == -1) return false; + TYPE_ASSIGN_CHECK(*in_attrs, 0, in_attrs->at(1)); + TYPE_ASSIGN_CHECK(*in_attrs, 1, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*out_attrs, 
0, mshadow::kBool); + return true; +} + +TBlob PrependAxes(const TBlob& src, const int dst_ndim) { + CHECK_LE(src.shape_.ndim(), dst_ndim); + const int src_ndim = src.shape_.ndim(); + if (src_ndim == dst_ndim) return src; + mxnet::TShape dst_shape(dst_ndim, 1); + for (int i = dst_ndim - src_ndim; i < dst_ndim; ++i) { + dst_shape[i] = src.shape_[i - dst_ndim + src_ndim]; + } + return src.reshape(dst_shape); +} + +struct TVMBinaryBroadcastCompute { + const char* func; + void operator()(const nnvm::NodeAttrs& attrs, + const mxnet::OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { +#if MXNET_USE_TVM_OP + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + if (outputs[0].shape_.Size() == 0U) return; // skip zero-size tensor + + // prepare tblobs and TVMArgs + std::vector tblobs = {inputs[0], inputs[1], outputs[0]}; + std::vector type_codes; + std::vector values; + + const int ondim = outputs[0].shape_.ndim(); + const size_t num_args = inputs.size() + outputs.size(); + type_codes.resize(num_args); + values.resize(num_args); + for (size_t i = 0; i < num_args; ++i) { + tblobs[i] = PrependAxes(tblobs[i], ondim); + type_codes[i] = kArrayHandle; + values[i].v_handle = const_cast(&(tblobs[i].dltensor())); + } + tvm::runtime::TVMArgs tvm_args(&values[0], &type_codes[0], tblobs.size()); + tvm::runtime::TVMOpModule::Get()->CallEx(func, ctx, tblobs, tvm_args); +#else + LOG(FATAL) << "Please add USE_TVM_OP=1 as a compile flag for compiling MXNet source code " + "to enable TVM-generated kernels for operator " << func; +#endif // MXNET_USE_TVM_OP + } +}; + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(name) \ + NNVM_REGISTER_OP(_npi_##name) \ + .set_num_inputs(2) \ + .set_num_outputs(1) \ + .set_attr("FListInputNames", \ + [](const NodeAttrs& attrs) { \ + return std::vector{"lhs", "rhs"}; \ + }) \ + .set_attr("FInferShape", BinaryBroadcastShape) \ + .set_attr("FInferType", NumpyBinaryLogicOpType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs) { \ + return std::vector >{{0, 0}, {1, 0}}; \ + }) \ + .set_attr("FCompute", TVMBinaryBroadcastCompute{func_##name##_cpu}) \ + .set_attr("FGradient", MakeZeroGradNodes) \ + .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \ + .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") + +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(less_equal); + +#if MXNET_USE_CUDA +#define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(name) \ + NNVM_REGISTER_OP(_npi_##name) \ + .set_attr("FCompute", TVMBinaryBroadcastCompute{func_##name##_gpu}) + +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(less_equal); +#endif // MXNET_USE_CUDA + +bool NumpyBinaryScalarLogicOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + if (in_attrs->at(0) == -1) return false; + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool); + return true; +} + +struct 
TVMBinaryBroadcastScalarCompute { + const char* func; + void operator()(const nnvm::NodeAttrs& attrs, + const mxnet::OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { +#if MXNET_USE_TVM_OP + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + if (outputs[0].shape_.Size() == 0U) return; // skip zero-size tensor + + // prepare tblobs and TVMArgs + std::vector tblobs = {inputs[0], outputs[0]}; + std::vector type_codes; + std::vector values; + + const size_t num_args = 3; // one input tensor, one scalar param, and one output + type_codes.resize(num_args); + values.resize(num_args); + + // input tensor setup + type_codes[0] = kArrayHandle; + values[0].v_handle = const_cast(&(tblobs[0].dltensor())); + + // scalar param + type_codes[1] = kDLFloat; + values[1].v_float64 = nnvm::get(attrs.parsed); + + // output tensor + type_codes[2] = kArrayHandle; + values[2].v_handle = const_cast(&(tblobs[1].dltensor())); + + tvm::runtime::TVMArgs tvm_args(&values[0], &type_codes[0], 3); + tvm::runtime::TVMOpModule::Get()->CallEx(func, ctx, tblobs, tvm_args); +#else + LOG(FATAL) << "Please add USE_TVM_OP=1 as a compile flag for compiling MXNet source code " + "to enable TVM-generated kernels for operator " << func; +#endif // MXNET_USE_TVM_OP + } +}; + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(name) \ + NNVM_REGISTER_OP(_npi_##name) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr_parser([](NodeAttrs* attrs) { \ + attrs->parsed = std::stod(attrs->dict["scalar"]); \ + }) \ + .set_attr("FListInputNames", \ + [](const NodeAttrs& attrs) { \ + return std::vector{"data"}; \ + }) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyBinaryScalarLogicOpType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs) { \ + return std::vector >{{0, 0}}; \ + }) \ + .set_attr("FCompute", TVMBinaryBroadcastScalarCompute{func_##name##_cpu}) \ + .set_attr("FGradient", MakeZeroGradNodes) \ + .add_argument("data", "NDArray-or-Symbol", "First input to the function") \ + .add_argument("scalar", "float", "scalar input") + +static constexpr char func_equal_scalar_cpu[] = "equal_scalar_cpu"; +static constexpr char func_equal_scalar_gpu[] = "equal_scalar_gpu"; +static constexpr char func_not_equal_scalar_cpu[] = "not_equal_scalar_cpu"; +static constexpr char func_not_equal_scalar_gpu[] = "not_equal_scalar_gpu"; +static constexpr char func_greater_scalar_cpu[] = "greater_scalar_cpu"; +static constexpr char func_greater_scalar_gpu[] = "greater_scalar_gpu"; +static constexpr char func_less_scalar_cpu[] = "less_scalar_cpu"; +static constexpr char func_less_scalar_gpu[] = "less_scalar_gpu"; +static constexpr char func_greater_equal_scalar_cpu[] = "greater_equal_scalar_cpu"; +static constexpr char func_greater_equal_scalar_gpu[] = "greater_equal_scalar_gpu"; +static constexpr char func_less_equal_scalar_cpu[] = "less_equal_scalar_cpu"; +static constexpr char func_less_equal_scalar_gpu[] = "less_equal_scalar_gpu"; + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(equal_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(not_equal_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(greater_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(less_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(greater_equal_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(less_equal_scalar); + +#if MXNET_USE_CUDA +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(name) \ + 
NNVM_REGISTER_OP(_npi_##name) \ + .set_attr("FCompute", TVMBinaryBroadcastScalarCompute{func_##name##_gpu}) + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(equal_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(not_equal_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(greater_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(less_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(greater_equal_scalar); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(less_equal_scalar); +#endif // MXNET_USE_CUDA } // namespace op } // namespace mxnet diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index b81cd78ad507..d0bcdda0062e 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -59,6 +59,7 @@ IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(int8_t); IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(uint8_t); IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(int32_t); IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(int64_t); +IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(bool); /*! * \brief Init variable used to facilitate registering a tunable operator during @@ -84,6 +85,16 @@ struct static_init_var { __macro$(__VA_ARGS__, int32_t); \ __macro$(__VA_ARGS__, int64_t); +#define MSHADOW_MACRO_FOREACH_TYPE_WITH_BOOL(__macro$, ...) \ + __macro$(__VA_ARGS__, float); \ + __macro$(__VA_ARGS__, double); \ + __macro$(__VA_ARGS__, mshadow::half::half_t); \ + __macro$(__VA_ARGS__, uint8_t); \ + __macro$(__VA_ARGS__, int8_t); \ + __macro$(__VA_ARGS__, int32_t); \ + __macro$(__VA_ARGS__, int64_t); \ + __macro$(__VA_ARGS__, bool) + #define IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(__op$, __typ$) \ namespace mxnet_op { \ template<> std::vector mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::workload_ = \ @@ -183,9 +194,15 @@ struct static_init_var { #define IMPLEMENT_UNARY_WORKLOAD_FWD(__op$) \ MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_UNARY_WORKLOAD_FWD, __op$) +#define IMPLEMENT_UNARY_WORKLOAD_FWD_WITH_BOOL(__op$) \ + MSHADOW_MACRO_FOREACH_TYPE_WITH_BOOL(_IMPLEMENT_UNARY_WORKLOAD_FWD, __op$) + #define IMPLEMENT_BLANK_WORKLOAD_FWD(__op$) \ MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_BLANK_WORKLOAD_FWD, __op$) +#define IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(__op$) \ + MSHADOW_MACRO_FOREACH_TYPE_WITH_BOOL(_IMPLEMENT_BLANK_WORKLOAD_FWD, __op$) + #define IMPLEMENT_UNARY_WORKLOAD_BWD(__op$) \ MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_UNARY_WORKLOAD_BWD, __op$) @@ -206,7 +223,7 @@ struct static_init_var { * integer value */ OperatorTuneBase::duration_t OperatorTuneBase::omp_overhead_ns_ = 5000; -IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::identity); // NOLINT() +IMPLEMENT_UNARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::identity); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::identity_grad); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::negation); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal); // NOLINT() @@ -370,8 +387,8 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::lcm); // NOLINT() -IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>); // NOLINT() -IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<1>); // NOLINT() +IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<0>); // NOLINT() 
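Taken together, the registrations above give every np comparison operator a scalar variant whose output dtype is forced to bool. A hedged usage sketch, assuming MXNet was built with USE_TVM_OP=1 (without it, the FCompute shown above aborts with the LOG(FATAL) message)::

    from mxnet import np

    a = np.array([1., 2., 3.])
    m = a > 1.5          # dispatches to _npi_greater_scalar; the scalar is parsed via std::stod
    print(m.dtype)       # bool, fixed by NumpyBinaryScalarLogicOpType
    print(int(m.sum()))  # 2; the boolean sum reduces through the TVM-generated kernel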
+IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<1>); // NOLINT() IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel); // NOLINT() /*! * \brief Tuner objects, *not* automatically generated diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc index e3c2e0e898d9..cd433e00a770 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc +++ b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc @@ -30,7 +30,6 @@ namespace mxnet { namespace op { MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_equal) -.add_alias("_npi_equal") .describe(R"code(Returns the result of element-wise **equal to** (==) comparison operation with broadcasting. Example:: @@ -49,7 +48,6 @@ Example:: .set_attr("FGradient", MakeZeroGradNodes); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_not_equal) -.add_alias("_npi_not_equal") .describe(R"code(Returns the result of element-wise **not equal to** (!=) comparison operation with broadcasting. Example:: @@ -68,7 +66,6 @@ Example:: .set_attr("FGradient", MakeZeroGradNodes); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_greater) -.add_alias("_npi_greater") .describe(R"code(Returns the result of element-wise **greater than** (>) comparison operation with broadcasting. Example:: @@ -87,7 +84,6 @@ Example:: .set_attr("FGradient", MakeZeroGradNodes); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_greater_equal) -.add_alias("_npi_greater_equal") .describe(R"code(Returns the result of element-wise **greater than or equal to** (>=) comparison operation with broadcasting. Example:: @@ -106,7 +102,6 @@ Example:: .set_attr("FGradient", MakeZeroGradNodes); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_lesser) -.add_alias("_npi_less") .describe(R"code(Returns the result of element-wise **lesser than** (<) comparison operation with broadcasting. Example:: @@ -125,7 +120,6 @@ Example:: .set_attr("FGradient", MakeZeroGradNodes); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_lesser_equal) -.add_alias("_npi_less_equal") .describe(R"code(Returns the result of element-wise **lesser than or equal to** (<=) comparison operation with broadcasting. 
Example:: diff --git a/src/operator/tensor/elemwise_binary_scalar_op_logic.cc b/src/operator/tensor/elemwise_binary_scalar_op_logic.cc index 87ba394c99b2..17e76153ebb2 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_logic.cc +++ b/src/operator/tensor/elemwise_binary_scalar_op_logic.cc @@ -71,32 +71,26 @@ static bool BinaryScalarLogicStorageType(const nnvm::NodeAttrs& attrs, MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_equal_scalar, mshadow_op::eq) -.add_alias("_npi_equal_scalar") .set_attr("FGradient", MakeZeroGradNodes) .add_alias("_EqualScalar"); MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_not_equal_scalar, mshadow_op::ne) -.add_alias("_npi_not_equal_scalar") .set_attr("FGradient", MakeZeroGradNodes) .add_alias("_NotEqualScalar"); MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_greater_scalar, mshadow_op::gt) -.add_alias("_npi_greater_scalar") .set_attr("FGradient", MakeZeroGradNodes) .add_alias("_GreaterScalar"); MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_greater_equal_scalar, mshadow_op::ge) -.add_alias("_npi_greater_equal_scalar") .set_attr("FGradient", MakeZeroGradNodes) .add_alias("_GreaterEqualScalar"); MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_lesser_scalar, mshadow_op::lt) -.add_alias("_npi_less_scalar") .set_attr("FGradient", MakeZeroGradNodes) .add_alias("_LesserScalar"); MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_lesser_equal_scalar, mshadow_op::le) -.add_alias("_npi_less_equal_scalar") .set_attr("FGradient", MakeZeroGradNodes) .add_alias("_LesserEqualScalar"); diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index b50f89b8c7fa..22e7652a4019 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -341,7 +341,7 @@ class UnaryOp : public OpBase { break; case kAddTo: { Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, { mxnet_op::Kernel, xpu>::Launch( s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr()); }); diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index f3c405d7103c..8e8896e5c014 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -57,7 +57,7 @@ struct InitOpParam : public dmlc::Parameter { .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)." 
"Only used for imperative calls."); DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32) - MXNET_ADD_ALL_TYPES + MXNET_ADD_ALL_TYPES_WITH_BOOL .describe("Target data type."); } }; @@ -342,12 +342,12 @@ void Fill(mshadow::Stream *s, const TBlob& b, const OpReqType req, ValueTyp if (val == 0) { if (req != kAddTo) { if (b.dev_mask() == cpu::kDevMask && size < 50000) { - MSHADOW_TYPE_SWITCH(b.type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(b.type_flag_, DType, { memset(b.dptr_, 0, size * sizeof(DType)); }); } else { // Optimize common use-case of filling with ones - MSHADOW_TYPE_SWITCH(b.type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(b.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req, Req, { mxnet_op::Kernel, Req>, xpu>::Launch( s, b.Size(), b.dptr()); @@ -357,7 +357,7 @@ void Fill(mshadow::Stream *s, const TBlob& b, const OpReqType req, ValueTyp } } else if (is_integer && val == 1) { // Optimize common use-case of filling with ones - MSHADOW_TYPE_SWITCH(b.type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(b.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req, Req, { mxnet_op::Kernel, xpu>::Launch( s, b.Size(), b.dptr()); @@ -365,7 +365,7 @@ void Fill(mshadow::Stream *s, const TBlob& b, const OpReqType req, ValueTyp }); } else { // Generic fill kernel from variable - MSHADOW_TYPE_SWITCH(b.type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(b.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req, Req, { mxnet_op::Kernel, xpu>::Launch( s, b.Size(), b.dptr(), static_cast(val)); diff --git a/src/operator/tvmop/op_module.cc b/src/operator/tvmop/op_module.cc index d1d1c1d45d53..ea7f73069698 100644 --- a/src/operator/tvmop/op_module.cc +++ b/src/operator/tvmop/op_module.cc @@ -72,6 +72,9 @@ PackedFunc GetFunction(const std::shared_ptr &module, case mshadow::kInt64: func_name << "int64"; break; + case mshadow::kBool: + func_name << "bool"; + break; default: LOG(FATAL) << "Unknown dtype " << arg.type_flag_; } @@ -82,7 +85,7 @@ PackedFunc GetFunction(const std::shared_ptr &module, void TVMOpModule::Call(const std::string &func_name, const mxnet::OpContext& ctx, - const std::vector &args) { + const std::vector &args) const { std::vector type_codes; std::vector values; @@ -112,6 +115,28 @@ void TVMOpModule::Call(const std::string &func_name, #endif } +void TVMOpModule::CallEx(const std::string &func_name, + const mxnet::OpContext& ctx, + const std::vector& tblobs, + TVMArgs tvm_args) const { + TVMRetValue rv; + +#if MXNET_USE_CUDA + int dev_type = (ctx.run_ctx.ctx.dev_type == mxnet::Context::DeviceType::kGPU) ? kDLGPU : kDLCPU; + int dev_id = ctx.run_ctx.ctx.dev_id; + if (dev_type == kDLGPU) { + void *stream = static_cast(ctx.run_ctx.get_stream()->stream_); + TVMSetStream(dev_type, dev_id, stream); + } +#endif + GetFunction(module_ptr_, func_name, tblobs).CallPacked(tvm_args, &rv); +#if MXNET_USE_CUDA + if (dev_type == kDLGPU) { + TVMSetStream(dev_type, dev_id, nullptr); + } +#endif +} + } // namespace runtime } // namespace tvm #endif // MXNET_USE_TVM_OP diff --git a/src/operator/tvmop/op_module.h b/src/operator/tvmop/op_module.h index 04e97ef51f4e..d28dd1b5b0c2 100644 --- a/src/operator/tvmop/op_module.h +++ b/src/operator/tvmop/op_module.h @@ -36,6 +36,7 @@ namespace tvm { namespace runtime { +class TVMArgs; class Module; class TVMOpModule { public: @@ -44,7 +45,22 @@ class TVMOpModule { void Call(const std::string& func_name, const mxnet::OpContext& ctx, - const std::vector& args); + const std::vector& args) const; + + /*! 
+ * \brief Launch operator kernels which have been pre-compiled into a lib file + * by TVM compiler. + * \param func_name Function name that corresponds to the operator kernel + * \param ctx Operator context that includes device and stream information. + * \param tblobs Tensor blobs whose dtype and shape information are extracted + * to construct the function name. Each configuration of dtype and shape has + * a unique kernel. + * \param tvm_args Arguments to be passed to kernel function. + */ + void CallEx(const std::string &func_name, + const mxnet::OpContext& ctx, + const std::vector& tblobs, + TVMArgs tvm_args) const; static TVMOpModule *Get() { static TVMOpModule inst; diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py index 62ea38fc0c13..c6e406138701 100644 --- a/tests/python/unittest/test_numpy_gluon.py +++ b/tests/python/unittest/test_numpy_gluon.py @@ -119,18 +119,18 @@ def test_np_loss_ndarray(): weighting = np.array([0.5, 1, 0.5, 1]) loss = gluon.loss.L1Loss() - assert np.sum(loss(output, label)) == 6. + assert float(np.sum(loss(output, label))) == 6. loss = gluon.loss.L1Loss(weight=0.5) - assert np.sum(loss(output, label)) == 3. + assert float(np.sum(loss(output, label))) == 3. loss = gluon.loss.L1Loss() - assert np.sum(loss(output, label, weighting)) == 5. + assert float(np.sum(loss(output, label, weighting))) == 5. loss = gluon.loss.L2Loss() - assert np.sum(loss(output, label)) == 7. + assert float(np.sum(loss(output, label))) == 7. loss = gluon.loss.L2Loss(weight=0.25) - assert np.sum(loss(output, label)) == 1.75 + assert float(np.sum(loss(output, label))) == 1.75 loss = gluon.loss.L2Loss() - assert np.sum(loss(output, label, weighting)) == 6 + assert float(np.sum(loss(output, label, weighting))) == 6 output = np.array([[0, 2], [1, 4]]) label = np.array([0, 1]) diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 3a5e72b53d58..20b964c96a30 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -19,13 +19,14 @@ from __future__ import absolute_import from __future__ import division import os +import unittest import numpy as _np import mxnet as mx from mxnet import np, npx, autograd from mxnet.gluon import HybridBlock -from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, retry, assert_exception, use_np +from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, retry, use_np from common import with_seed, TemporaryDirectory -from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf +from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, has_tvm_ops, assert_exception from mxnet.ndarray.ndarray import py_slice from mxnet.base import integer_types import scipy.stats as ss @@ -34,8 +35,23 @@ @with_seed() @use_np def test_np_empty(): - dtypes = [np.int8, np.int32, np.float16, np.float32, np.float64, None] - expected_dtypes = [np.int8, np.int32, np.float16, np.float32, np.float64, np.float32] + # (input dtype, expected output dtype) + dtype_pairs = [ + (np.int8, np.int8), + (np.int32, np.int32), + (np.float16, np.float16), + (np.float32, np.float32), + (np.float64, np.float64), + (np.bool_, np.bool_), + (np.bool, np.bool_), + ('int8', np.int8), + ('int32', np.int32), + ('float16', np.float16), + ('float32', np.float32), + ('float64', np.float64), + ('bool', np.bool_), + (None, np.float32), + ] orders = ['C', 'F', 'A'] shapes = [ (), @@ 
-49,7 +65,7 @@ def test_np_empty(): (1, 1, 1, 1), ] ctxes = [npx.current_context(), None] - for dtype, expected_dtype in zip(dtypes, expected_dtypes): + for dtype, expected_dtype in dtype_pairs: for shape in shapes: for order in orders: for ctx in ctxes: @@ -65,7 +81,8 @@ def test_np_empty(): @with_seed() @use_np def test_np_array_creation(): - dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, None] + dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, _np.bool, _np.bool_, + 'int8', 'int32', 'float16', 'float32', 'float64', 'bool', None] objects = [ [], (), @@ -76,7 +93,7 @@ def test_np_array_creation(): for dtype in dtypes: for src in objects: mx_arr = np.array(src, dtype=dtype) - assert mx_arr.context == mx.current_context() + assert mx_arr.ctx == mx.current_context() if isinstance(src, mx.nd.NDArray): np_arr = _np.array(src.asnumpy(), dtype=dtype if dtype is not None else _np.float32) else: @@ -110,6 +127,8 @@ def check_zero_array_creation(shape, dtype): if dtype is None: assert mx_out.dtype == _np.float32 assert np_out.dtype == _np.float64 + else: + assert mx_out.dtype == np_out.dtype shapes = [(0,), (2, 0, 2), (0, 0, 0, 0), ()] shapes += [rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)] @@ -132,6 +151,10 @@ def check_zero_array_creation(shape, dtype): y = test_zeros_output_type(x) assert type(y[1]) == np.ndarray + for shape in shapes: + for dtype in [_np.bool, bool, _np.bool_, 'bool']: + check_zero_array_creation(shape, dtype) + @with_seed() @use_np @@ -158,6 +181,8 @@ def check_ones_array_creation(shape, dtype): if dtype is None: assert mx_out.dtype == _np.float32 assert np_out.dtype == _np.float64 + else: + assert mx_out.dtype == np_out.dtype shapes = [(0,), (2, 0, 2), (0, 0, 0, 0), ()] shapes += [rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)] @@ -180,6 +205,10 @@ def check_ones_array_creation(shape, dtype): y = test_ones_output_type(x) assert type(y[1]) == np.ndarray + for shape in shapes: + for dtype in [_np.bool, bool, _np.bool_, 'bool']: + check_ones_array_creation(shape, dtype) + @with_seed() @use_np @@ -235,13 +264,18 @@ def test_np_ndarray_binary_element_wise_ops(): '/': _np.divide, 'mod': _np.mod, 'pow': _np.power, - '==': _np.equal, - '>': _np.greater, - '>=': _np.greater_equal, - '<': _np.less, - '<=': _np.less_equal } + if has_tvm_ops(): + np_op_map.update({ + '==': _np.equal, + '!=': _np.not_equal, + '>': _np.greater, + '>=': _np.greater_equal, + '<': _np.less, + '<=': _np.less_equal + }) + def get_np_ret(x1, x2, op): return np_op_map[op](x1, x2) @@ -309,10 +343,16 @@ def hybrid_forward(self, F, x, *args): return x == self._scalar if not self._reverse else self._scalar == x else: return x == args[0] + elif self._op == '!=': + if self._scalar is not None: + return x != self._scalar if not self._reverse else self._scalar != x + else: + return x != args[0] else: print(self._op) assert False + logic_ops = ['==', '!=', '>', '<', '>=', '<='] @use_np def check_binary_op_result(shape1, shape2, op, dtype=None): if shape1 is None: @@ -348,6 +388,8 @@ def check_binary_op_result(shape1, shape2, op, dtype=None): mx_out = get_mx_ret_np(mx_input1.as_np_ndarray(), mx_input2.as_np_ndarray()) assert type(mx_out) == np.ndarray assert np_out.shape == mx_out.shape + if op in logic_ops: + assert np_out.dtype == mx_out.dtype assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5) else: get_mx_ret = TestBinaryElementWiseOp(op, scalar=scalar, reverse=reverse) @@ -360,6 +402,8 @@ def check_binary_op_result(shape1,
shape2, op, dtype=None): mx_out = get_mx_ret(mx_input1.as_np_ndarray()) assert type(mx_out) == np.ndarray assert np_out.shape == mx_out.shape + if op in logic_ops: + assert np_out.dtype == mx_out.dtype assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5) dtypes = [_np.float32, _np.float64, None] @@ -933,6 +977,87 @@ def test_np_multinomial(): mx.test_utils.assert_almost_equal(freq[i, :], pvals, rtol=0.20, atol=1e-1) + +@with_seed() +@unittest.skipUnless(has_tvm_ops(), "Comparison ops are implemented using TVM") +@use_np +def test_np_ndarray_boolean_indexing(): + def test_single_bool_index(): + # adapted from numpy's test_indexing.py + # Single boolean index + a = np.array([[1, 2, 3], + [4, 5, 6], + [7, 8, 9]], dtype=np.int32) + assert same(a[np.array(True, dtype=np.bool_)].asnumpy(), a[None].asnumpy()) + assert same(a[np.array(False, dtype=np.bool_)].asnumpy(), a[None][0:0].asnumpy()) + + def test_boolean_catch_exception(): + # adapted from numpy's test_indexing.py + arr = np.ones((5, 4, 3)) + + index = np.array([True], dtype=np.bool_) + assert_exception(arr.__getitem__, IndexError, index) + + index = np.array([False] * 6, dtype=np.bool_) + assert_exception(arr.__getitem__, IndexError, index) + + index = np.zeros((4, 4), dtype=bool) + assert_exception(arr.__getitem__, IndexError, index) + + assert_exception(arr.__getitem__, TypeError, (slice(None), index)) + + def test_boolean_indexing_onedim(): + # adapted from numpy's test_indexing.py + # Indexing a 2-dimensional array with + # boolean array of length one + a = np.array([[0., 0., 0.]]) + b = np.array([True], dtype=bool) + assert same(a[b].asnumpy(), a.asnumpy()) + + def test_boolean_indexing_twodim(): + # adapted from numpy's test_indexing.py + # Indexing a 2-dimensional array with + # 2-dimensional boolean array + a = np.array([[1, 2, 3], + [4, 5, 6], + [7, 8, 9]], dtype=np.int32) + b = np.array([[ True, False, True], + [False, True, False], + [ True, False, True]], dtype=np.bool_) + assert same(a[b].asnumpy(), _np.array([1, 3, 5, 7, 9], dtype=a.dtype)) + assert same(a[b[1]].asnumpy(), _np.array([[4, 5, 6]], dtype=a.dtype)) + assert same(a[b[0]].asnumpy(), a[b[2]].asnumpy()) + + def test_boolean_indexing_list(): + # adapted from numpy's test_indexing.py + a = np.array([1, 2, 3], dtype=np.int32) + b = [True, False, True] + # Two variants of the test because the first takes a fast path + assert same(a[b].asnumpy(), _np.array([1, 3], dtype=a.dtype)) + assert same(a[None, b].asnumpy(), _np.array([[1, 3]], dtype=a.dtype)) + + def test_boolean_indexing_autograd(): + a = np.random.uniform(size=(3, 4, 5)) + a.attach_grad() + with mx.autograd.record(): + out_mx = a[a < 0.5] + out_mx.backward() + + a_np = a.asnumpy() + out_np = a_np[a_np < 0.5] + assert_almost_equal(out_mx.asnumpy(), out_np, rtol=1e-4, atol=1e-5, use_broadcast=False) + + a_grad_np = _np.zeros(a.shape, dtype=a.dtype) + a_grad_np[a_np < 0.5] = 1 + assert_almost_equal(a.grad.asnumpy(), a_grad_np, rtol=1e-4, atol=1e-5, use_broadcast=False) + + test_single_bool_index() + test_boolean_catch_exception() + test_boolean_indexing_onedim() + test_boolean_indexing_twodim() + test_boolean_indexing_list() + test_boolean_indexing_autograd() + + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 224125eb64f3..4f7e3e4be7f2 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -29,15 +29,11 @@ from common import assertRaises, with_seed import random import scipy.stats as ss -from
mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, retry -from mxnet.runtime import Features from mxnet.numpy_op_signature import _get_builtin_op +from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, has_tvm_ops import platform -_features = Features() - - @with_seed() @use_np def test_np_tensordot(): @@ -240,13 +236,14 @@ def is_int(dtype): in_data_dim = random.choice([2, 3, 4]) shape = rand_shape_nd(in_data_dim, dim=3) acc_type = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64', - 'int8': 'int32', 'int32': 'int64', 'int64': 'int64'} + 'int8': 'int32', 'int32': 'int64', 'int64': 'int64', 'bool': 'int64'} for hybridize in [False, True]: for keepdims in [True, False]: for axis in ([i for i in range(in_data_dim)] + [(), None]): - for itype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']: + for itype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64', 'bool']: for dtype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']: - if is_int(dtype) and not is_int(itype): + if (is_int(dtype) and not is_int(itype))\ + or (itype == 'bool' and dtype not in ('float32', 'float64', 'int32', 'int64')): continue # test gluon test_sum = TestSum(axis=axis, dtype=dtype, keepdims=keepdims) @@ -254,13 +251,23 @@ def is_int(dtype): test_sum.hybridize() if is_int(itype): x = _np.random.randint(-128, 128, shape, dtype=itype) - x = mx.nd.array(x) + x = np.array(x) + elif itype == 'bool': + x = _np.random.randint(0, 2, shape) < 1 + x = np.array(x, dtype='bool') else: - x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype) - x = x.as_np_ndarray() - x.attach_grad() + x = np.random.uniform(-1.0, 1.0, size=shape, dtype=itype) expected_ret = _np.sum(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims) expected_ret = expected_ret.astype(dtype) + if itype == 'bool': # special handling of boolean ndarray + if has_tvm_ops(): + y = test_sum(x) + assert y.dtype == expected_ret.dtype + assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-4, atol=1e-5, + use_broadcast=False) + continue + + x.attach_grad() with mx.autograd.record(): y = test_sum(x) assert y.shape == expected_ret.shape @@ -1119,7 +1126,7 @@ def hybrid_forward(self, F, a, *args, **kwargs): 'arccosh' : (lambda x: 1./(x**2 - 1.)**(1./2.), 2.0, 5.0), 'arctanh' : (lambda x: -1./(x**2 - 1.), -0.99, 0.99) } - if _features.is_enabled("TVM_OP"): + if has_tvm_ops(): funcs['rad2deg'] = (lambda x: 180. / _np.pi * _np.ones(x.shape), -1.0, 1.0) funcs['deg2rad'] = (lambda x: _np.pi / 180. * _np.ones(x.shape), -1.0, 1.0) ndim = random.choice([2, 3, 4]) @@ -1885,7 +1892,7 @@ def test_np_indices(): (2, 3, 4, 5, 6, 7) ] if platform.system() == 'Windows': - shapes = shapes[1:] #beacuse in numpy windows version, indces not support dimensions is empty tuple. + shapes = shapes[1:] # because numpy's indices on Windows does not support an empty tuple as the dimensions argument. for dtype in dtypes: for shape in shapes: np_out = _np.indices(dimensions=shape, dtype=dtype)
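A closing note on the test changes: the bool branch of test_np_sum above exercises only the forward pass, since NumpyReduceAxesBackwardUseNone rejects boolean inputs for gradient computation. A minimal standalone test following the same gating pattern might look like the sketch below; the test name and skip message are illustrative, not part of the PR::

    import unittest
    import numpy as _np
    from mxnet import np
    from mxnet.test_utils import has_tvm_ops, use_np

    @unittest.skipUnless(has_tvm_ops(), "bool reduction relies on TVM-generated kernels")
    @use_np
    def test_np_bool_sum_default_dtype():
        a = np.array([[True, False], [True, True]], dtype=_np.bool_)
        s = np.sum(a)
        # NumpySumType assigns an int64 output when a bool input has no explicit dtype
        assert s.dtype == _np.int64
        assert int(s) == 3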