diff --git a/CMakeLists.txt b/CMakeLists.txt index 1d1bf7dc1a44..0f4ea86528f3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -57,6 +57,11 @@ message(STATUS "CMAKE_HOST_SYSTEM_PROCESSOR ${CMAKE_HOST_SYSTEM_PROCESSOR}") message(STATUS "CMAKE_SYSTEM_PROCESSOR ${CMAKE_SYSTEM_PROCESSOR}") message(STATUS "CMAKE_SYSTEM_NAME ${CMAKE_SYSTEM_NAME}") + +if(USE_TVM_OP) + add_definitions(-DMXNET_USE_TVM_OP=1) +endif() + if(USE_CUDA AND NOT USE_OLDCMAKECUDA) message(STATUS "CMake version '${CMAKE_VERSION}' using generator '${CMAKE_GENERATOR}'") if( @@ -739,7 +744,6 @@ if(USE_DIST_KVSTORE) endif() if(USE_TVM_OP) - add_definitions(-DMXNET_USE_TVM_OP=1) list(APPEND mxnet_LINKER_LIBS ${CMAKE_CURRENT_BINARY_DIR}/3rdparty/tvm/libtvm_runtime.so) include(cmake/BuildTVM.cmake) add_subdirectory("3rdparty/tvm") diff --git a/ci/docker_cache.py b/ci/docker_cache.py index f906b0eba66c..3a2a1fb415ee 100755 --- a/ci/docker_cache.py +++ b/ci/docker_cache.py @@ -37,7 +37,7 @@ DOCKERHUB_LOGIN_NUM_RETRIES = 5 DOCKERHUB_RETRY_SECONDS = 5 DOCKER_CACHE_NUM_RETRIES = 3 -DOCKER_CACHE_TIMEOUT_MINS = 15 +DOCKER_CACHE_TIMEOUT_MINS = 45 PARALLEL_BUILDS = 10 diff --git a/docs/static_site/src/pages/api/faq/add_op_in_backend.md b/docs/static_site/src/pages/api/faq/add_op_in_backend.md index 672bf52a29ee..7595467575be 100644 --- a/docs/static_site/src/pages/api/faq/add_op_in_backend.md +++ b/docs/static_site/src/pages/api/faq/add_op_in_backend.md @@ -1,6 +1,6 @@ --- layout: page_category -title: Exception Handling in MXNet +title: A Beginner's Guide to Implementing Operators in MXNet Backend category: faq faq_c: Extend and Contribute to MXNet question: How do I implement operators in MXNet backend? diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py index 6e2f5fa15919..8ad5247fa263 100644 --- a/python/mxnet/_numpy_op_doc.py +++ b/python/mxnet/_numpy_op_doc.py @@ -653,3 +653,54 @@ def _np_trace(a, offset=0, axis1=0, axis2=1, out=None): (2, 3) """ pass + + +def _np_squeeze(a, axis=None, out=None): + """ + Remove single-dimensional entries from the shape of an array. + + Parameters + ---------- + a : ndarray + Input data. + axis : None or int or tuple of ints, optional + Selects a subset of the single-dimensional entries in the + shape. If an axis is selected with shape entry greater than + one, an error is raised. + out : ndarray, optional + Array into which the output is placed. It must have the same size + and dtype as the input array. + + Returns + ------- + squeezed : ndarray + The input array, but with all or a subset of the + dimensions of length 1 removed. It always returns a copy of `a`. + + Raises + ------ + MXNetError + If `axis` is not `None`, and an axis being squeezed is not of length 1 + + See Also + -------- + expand_dims : The inverse operation, adding singleton dimensions + reshape : Insert, remove, and combine dimensions, and resize existing ones + + Examples + -------- + >>> x = np.array([[[0], [1], [2]]]) + >>> x.shape + (1, 3, 1) + >>> np.squeeze(x).shape + (3,) + >>> np.squeeze(x, axis=0).shape + (3, 1) + >>> np.squeeze(x, axis=1).shape + Traceback (most recent call last): + ... + mxnet.base.MXNetError: cannot select an axis to squeeze out which has size=3 not equal to one + >>> np.squeeze(x, axis=2).shape + (1, 3) + """ + pass diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 9dce953004cf..e3cf09703533 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -31,7 +31,7 @@ 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', - 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', + 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', @@ -741,6 +741,53 @@ def tensordot(a, b, axes=2): return _npi.tensordot(a, b, a_axes_summed, b_axes_summed) +@set_module('mxnet.ndarray.numpy') +def histogram(a, bins=10, range=None, normed=None, weights=None, density=None): # pylint: disable=too-many-arguments + """ + Compute the histogram of a set of data. + + Parameters + ---------- + a : ndarray + Input data. The histogram is computed over the flattened array. + bins : int or NDArray + If `bins` is an int, it defines the number of equal-width + bins in the given range (10, by default). If `bins` is a + sequence, it defines a monotonically increasing array of bin edges, + including the rightmost edge, allowing for non-uniform bin widths. + .. versionadded:: 1.11.0 + If `bins` is a string, it defines the method used to calculate the + optimal bin width, as defined by `histogram_bin_edges`. + range : (float, float) + The lower and upper range of the bins. Required when `bins` is an integer. + Values outside the range are ignored. The first element of the range must + be less than or equal to the second. + normed : bool, optional + Not supported yet, coming soon. + weights : array_like, optional + Not supported yet, coming soon. + density : bool, optional + Not supported yet, coming soon. + """ + if normed is True: + raise NotImplementedError("normed is not supported yet...") + if weights is not None: + raise NotImplementedError("weights is not supported yet...") + if density is True: + raise NotImplementedError("density is not supported yet...") + if isinstance(bins, numeric_types): + if range is None: + raise NotImplementedError("automatic range is not supported yet...") + return _npi.histogram(a, bin_cnt=bins, range=range) + if isinstance(bins, (list, tuple)): + raise NotImplementedError("array_like bins is not supported yet...") + if isinstance(bins, str): + raise NotImplementedError("string bins is not supported yet...") + if isinstance(bins, NDArray): + return _npi.histogram(a, bins=bins) + raise ValueError("np.histogram fails with", locals()) + + @set_module('mxnet.ndarray.numpy') def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments r""" @@ -2063,11 +2110,11 @@ def logical_not(x, out=None, **kwargs): -------- >>> x= np.array([True, False, 0, 1]) >>> np.logical_not(x) - array([0., 1., 1., 0.]) + array([False, True, True, False]) >>> x = np.arange(5) >>> np.logical_not(x<3) - array([0., 0., 0., 1., 1.]) + array([False, False, False, True, True]) """ return _unary_func_helper(x, _npi.logical_not, _np.logical_not, out=out, **kwargs) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 75b7cf65325b..5ee52f14bb16 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -50,12 +50,13 @@ 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', - 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', + 'tensordot', 'histogram', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal'] + # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 _NDARRAY_BASIC_INDEXING = 0 @@ -301,7 +302,7 @@ def __getitem__(self, key): except Exception as err: raise TypeError('{}'.format(str(err))) if isinstance(key, _np.ndarray) and key.dtype == _np.bool_: - key = array(key, dtype='bool') + key = array(key, dtype='bool', ctx=self.ctx) if isinstance(key, ndarray) and key.dtype == _np.bool_: # boolean indexing key_shape = key.shape key_ndim = len(key_shape) @@ -363,6 +364,8 @@ def __setitem__(self, key, value): """ if isinstance(value, NDArray) and not isinstance(value, ndarray): raise TypeError('Cannot assign mx.nd.NDArray to mxnet.numpy.ndarray') + + # handle basic and advanced indexing if self.ndim == 0: if not isinstance(key, tuple) or len(key) != 0: raise IndexError('scalar tensor can only accept `()` as index') @@ -752,7 +755,7 @@ def detach(self): check_call(_LIB.MXNDArrayDetach(self.handle, ctypes.byref(hdl))) return _np_ndarray_cls(hdl) - def astype(self, dtype, *args, **kwargs): # pylint: disable=arguments-differ,unused-argument + def astype(self, dtype, **kwargs): # pylint: disable=arguments-differ,unused-argument """ Copy of the array, cast to a specified type. @@ -1236,7 +1239,14 @@ def tile(self, *args, **kwargs): def transpose(self, *axes): # pylint: disable=arguments-differ """Permute the dimensions of an array.""" - return _mx_np_op.transpose(self, axes=axes if len(axes) != 0 else None) + if len(axes) == 0: + axes = None + elif len(axes) == 1: + if isinstance(axes[0], (tuple, list)): + axes = axes[0] + elif axes[0] is None: + axes = None + return _mx_np_op.transpose(self, axes=axes) def flip(self, *args, **kwargs): """Convenience fluent method for :py:func:`flip`. @@ -3400,11 +3410,11 @@ def logical_not(x, out=None, **kwargs): -------- >>> x= np.array([True, False, 0, 1]) >>> np.logical_not(x) - array([0., 1., 1., 0.]) + array([False, True, True, False]) >>> x = np.arange(5) >>> np.logical_not(x<3) - array([0., 0., 0., 1., 1.]) + array([False, False, False, True, True]) """ return _mx_nd_np.logical_not(x, out=out, **kwargs) @@ -3604,6 +3614,37 @@ def tensordot(a, b, axes=2): return _mx_nd_np.tensordot(a, b, axes) +@set_module('mxnet.numpy') +def histogram(a, bins=10, range=None, normed=None, weights=None, density=None): # pylint-disable=too-many-arguments + """ + Compute the histogram of a set of data. + + Parameters + ---------- + a : ndarray + Input data. The histogram is computed over the flattened array. + bins : int or NDArray + If `bins` is an int, it defines the number of equal-width + bins in the given range (10, by default). If `bins` is a + sequence, it defines a monotonically increasing array of bin edges, + including the rightmost edge, allowing for non-uniform bin widths. + .. versionadded:: 1.11.0 + If `bins` is a string, it defines the method used to calculate the + optimal bin width, as defined by `histogram_bin_edges`. + range : (float, float) + The lower and upper range of the bins. Required when `bins` is an integer. + Values outside the range are ignored. The first element of the range must + be less than or equal to the second. + normed : bool, optional + Not supported yet, coming soon. + weights : array_like, optional + Not supported yet, coming soon. + density : bool, optional + Not supported yet, coming soon. + """ + return _mx_nd_np.histogram(a, bins=bins, range=range, normed=normed, weights=weights, density=density) + + @set_module('mxnet.numpy') def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments r""" diff --git a/python/mxnet/numpy_extension/__init__.py b/python/mxnet/numpy_extension/__init__.py index 5a19d3dc4e2f..879ab4d56a46 100644 --- a/python/mxnet/numpy_extension/__init__.py +++ b/python/mxnet/numpy_extension/__init__.py @@ -25,7 +25,7 @@ from . import _register from ._op import * # pylint: disable=wildcard-import from ..context import * # pylint: disable=wildcard-import -from ..util import is_np_shape, is_np_array, set_np, reset_np +from ..util import is_np_shape, is_np_array, set_np, reset_np, get_cuda_compute_capability from ..ndarray import waitall from .utils import * # pylint: disable=wildcard-import from . import random # pylint: disable=wildcard-import diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 3eaf80a1b6fb..338c70d1e53d 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -23,7 +23,7 @@ import numpy as _np from . import _op as _mx_np_op from ...base import _LIB, SymbolHandle, numeric_types, mx_uint -from ...util import check_call, set_module +from ...util import check_call, set_module, _sanity_check_params from ...context import current_context from ..symbol import Symbol from .._internal import _set_np_symbol_class @@ -33,7 +33,7 @@ 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', - 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', + 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', @@ -181,8 +181,29 @@ def T(self): return self.transpose() # pylint: enable= invalid-name, undefined-variable - def astype(self, dtype, **kwargs): # pylint: disable=arguments-differ - raise NotImplementedError + def astype(self, dtype, **kwargs): # pylint: disable=arguments-differ,unused-argument + """ + Copy of the array, cast to a specified type. + + Parameters + ---------- + dtype : str or dtype + Typecode or data-type to which the array is cast. + copy : bool, optional + Default `True`. By default, astype always returns a newly + allocated ndarray on the same context. If this is set to + `False`, and the dtype requested is the same as the ndarray's + dtype, the ndarray is returned instead of a copy. + + Returns + ------- + arr_t : ndarray + Unless `copy` is False and the other conditions for returning the input + array are satisfied (see description for `copy` input parameter), `arr_t` + is a new array of the same shape as the input array with `dtype`. + """ + _sanity_check_params('astype', ['order', 'casting', 'subok'], kwargs) + return _npi.cast(self, dtype=dtype) def dot(self, b, out=None): """Dot product of two arrays. @@ -438,7 +459,14 @@ def transpose(self, *axes): # pylint: disable=arguments-differ """The arguments are the same as for :py:func:`transpose`, with this array as data. """ - return _mx_np_op.transpose(self, axes=axes if len(axes) != 0 else None) + if len(axes) == 0: + axes = None + elif len(axes) == 1: + if isinstance(axes[0], (tuple, list)): + axes = axes[0] + elif axes[0] is None: + axes = None + return _mx_np_op.transpose(self, axes=axes) def flip(self, *args, **kwargs): """Convenience fluent method for :py:func:`flip`. @@ -1230,6 +1258,53 @@ def tensordot(a, b, axes=2): return _npi.tensordot(a, b, a_axes_summed, b_axes_summed) +@set_module('mxnet.symbol.numpy') +def histogram(a, bins=10, range=None, normed=None, weights=None, density=None): # pylint: disable= too-many-arguments + """ + Compute the histogram of a set of data. + + Parameters + ---------- + a : Symbol + Input data. The histogram is computed over the flattened array. + bins : int or Symbol + If `bins` is an int, it defines the number of equal-width + bins in the given range (10, by default). If `bins` is a + sequence, it defines a monotonically increasing array of bin edges, + including the rightmost edge, allowing for non-uniform bin widths. + .. versionadded:: 1.11.0 + If `bins` is a string, it defines the method used to calculate the + optimal bin width, as defined by `histogram_bin_edges`. + range : (float, float) + The lower and upper range of the bins. Required when `bins` is an integer. + Values outside the range are ignored. The first element of the range must + be less than or equal to the second. + normed : bool, optional + Not supported yet, coming soon. + weights : array_like, optional + Not supported yet, coming soon. + density : bool, optional + Not supported yet, coming soon. + """ + if normed is True: + raise NotImplementedError("normed is not supported yet...") + if weights is not None: + raise NotImplementedError("weights is not supported yet...") + if density is True: + raise NotImplementedError("density is not supported yet...") + if isinstance(bins, numeric_types): + if range is None: + raise NotImplementedError("automatic range is not avaialble yet...") + return _npi.histogram(a, bin_cnt=bins, range=range) + if isinstance(bins, (list, tuple)): + raise NotImplementedError("array_like bins is not supported yet...") + if isinstance(bins, str): + raise NotImplementedError("string bins is not supported yet...") + if isinstance(bins, Symbol): + return _npi.histogram(a, bins) + raise ValueError("histogram fails with", locals()) + + @set_module('mxnet.symbol.numpy') def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments r""" diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index c5b3b0d8e4cd..e5f2c6c02089 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -51,6 +51,7 @@ from .symbol.numpy import _Symbol as np_symbol from .util import use_np # pylint: disable=unused-import from .runtime import Features +from .numpy_extension import get_cuda_compute_capability def default_context(): @@ -2235,10 +2236,34 @@ def has_tvm_ops(): """Returns True if MXNet is compiled with TVM generated operators. If current ctx is GPU, it only returns True for CUDA compute capability > 52 where FP16 is supported.""" built_with_tvm_op = _features.is_enabled("TVM_OP") - if current_context().device_type == 'gpu': + ctx = current_context() + if ctx.device_type == 'gpu': try: - import tvm - except ImportError: + cc = get_cuda_compute_capability(ctx) + except: # pylint: disable=bare-except + print('Failed to get CUDA compute capability for context {}. The operators ' + 'built with USE_TVM_OP=1 will not be run in unit tests.'.format(ctx)) return False - return built_with_tvm_op and (int("".join(tvm.nd.gpu(0).compute_version.split('.'))) >= 53) + print('Cuda arch compute capability: sm_{}'.format(str(cc))) + return built_with_tvm_op and cc >= 53 return built_with_tvm_op + + +def is_op_runnable(): + """Returns True for all CPU tests. Returns True for GPU tests that are either of the following. + 1. Built with USE_TVM_OP=0. + 2. Built with USE_TVM_OP=1, but with compute capability >= 53.""" + ctx = current_context() + if ctx.device_type == 'gpu': + if not _features.is_enabled("TVM_OP"): + return True + else: + try: + cc = get_cuda_compute_capability(ctx) + except: # pylint: disable=bare-except + print('Failed to get CUDA compute capability for context {}. The operators ' + 'built with USE_TVM_OP=1 will not be run in unit tests.'.format(ctx)) + return False + print('Cuda arch compute capability: sm_{}'.format(str(cc))) + return cc >= 53 + return True diff --git a/python/mxnet/util.py b/python/mxnet/util.py index d4e95e0c0c9c..1050fb2a481d 100644 --- a/python/mxnet/util.py +++ b/python/mxnet/util.py @@ -602,3 +602,64 @@ def set_np(shape=True, array=True): def reset_np(): """Deactivate NumPy shape and array semantics at the same time.""" set_np(shape=False, array=False) + + +_CUDA_SUCCESS = 0 + + +def get_cuda_compute_capability(ctx): + """Returns the cuda compute capability of the input `ctx`. + + Parameters + ---------- + ctx : Context + GPU context whose corresponding cuda compute capability is to be retrieved. + + Returns + ------- + cuda_compute_capability : int + CUDA compute capability. For example, it returns 70 for CUDA arch equal to `sm_70`. + + References + ---------- + https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549#file-cuda_check-py + """ + if ctx.device_type != 'gpu': + raise ValueError('Expecting a gpu context to get cuda compute capability, ' + 'while received ctx {}'.format(str(ctx))) + + libnames = ('libcuda.so', 'libcuda.dylib', 'cuda.dll') + for libname in libnames: + try: + cuda = ctypes.CDLL(libname) + except OSError: + continue + else: + break + else: + raise OSError("could not load any of: " + ' '.join(libnames)) + + # Some constants taken from cuda.h + + cc_major = ctypes.c_int() + cc_minor = ctypes.c_int() + device = ctypes.c_int() + error_str = ctypes.c_char_p() + + ret = cuda.cuInit(0) + if ret != _CUDA_SUCCESS: + cuda.cuGetErrorString(ret, ctypes.byref(error_str)) + raise RuntimeError('cuInit failed with erro code {}: {}' + .format(ret, error_str.value.decode())) + + ret = cuda.cuDeviceGet(ctypes.byref(device), ctx.device_id) + if ret != _CUDA_SUCCESS: + cuda.cuGetErrorString(ret, ctypes.byref(error_str)) + raise RuntimeError('cuDeviceGet failed with error code {}: {}' + .format(ret, error_str.value.decode())) + ret = cuda.cuDeviceComputeCapability(ctypes.byref(cc_major), ctypes.byref(cc_minor), device) + if ret != _CUDA_SUCCESS: + cuda.cuGetErrorString(ret, ctypes.byref(error_str)) + raise RuntimeError('cuDeviceComputeCapability failed with error code {}: {}' + .format(ret, error_str.value.decode())) + return cc_major.value * 10 + cc_minor.value diff --git a/src/common/utils.h b/src/common/utils.h index 2bd6aac6f5d9..fbecc8b4e955 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -823,6 +823,21 @@ static inline std::string GetOutputName(const nnvm::NodeEntry& e) { return sym.ListOutputNames()[0]; } +inline mxnet::TShape CanonicalizeAxes(const mxnet::TShape& src) { + // convert negative axes to positive values + const int ndim = src.ndim(); + mxnet::TShape axes = src; + for (int i = 0; i < ndim; ++i) { + if (axes[i] < 0) { + axes[i] += ndim; + } + CHECK(axes[i] >= 0 && axes[i] < ndim) << "axes[" << i << "]=" + << axes[i] << " exceeds the range [" + << 0 << ", " << ndim << ")"; + } + return axes; +} + } // namespace common } // namespace mxnet #endif // MXNET_COMMON_UTILS_H_ diff --git a/src/ndarray/ndarray_function.cc b/src/ndarray/ndarray_function.cc index 653dec84563a..34429446bd62 100644 --- a/src/ndarray/ndarray_function.cc +++ b/src/ndarray/ndarray_function.cc @@ -36,23 +36,14 @@ template<> void Copy(const TBlob &from, TBlob *to, Context from_ctx, Context to_ctx, RunContext ctx) { - if (from.type_flag_ == mshadow::kBool || to->type_flag_ == mshadow::kBool) { - CHECK_EQ(from.type_flag_, to->type_flag_) << "Only supports copying data between" - " two boolean tensors."; - const index_t size = from.Size(); - CHECK_EQ(size, to->Size()) << "copying size mismatch, from: " << size * sizeof(bool) - << " bytes, to: " << to->Size() * sizeof(bool) << " bytes."; - common::ParallelCopy(to->dptr(), from.dptr(), size); - return; - } - MSHADOW_TYPE_SWITCH(to->type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(to->type_flag_, DType, { if (to->type_flag_ == from.type_flag_) { const index_t size = static_cast(from.Size()); CHECK_EQ(size, to->Size()) << "copying size mismatch, from: " << size * sizeof(DType) << " bytes, to: " << to->Size() * sizeof(DType) << " bytes."; common::ParallelCopy(to->dptr(), from.dptr(), size); } else { - MSHADOW_TYPE_SWITCH(from.type_flag_, SrcDType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(from.type_flag_, SrcDType, { to->FlatTo1D() = mshadow::expr::tcast(from.FlatTo1D()); }) diff --git a/src/ndarray/ndarray_function.cu b/src/ndarray/ndarray_function.cu index 6439c417bfe3..2a1461cc8c48 100644 --- a/src/ndarray/ndarray_function.cu +++ b/src/ndarray/ndarray_function.cu @@ -76,10 +76,6 @@ void Copy(const TBlob &from, TBlob *to, from.FlatTo1D(s), s); } else { - CHECK_NE(from.type_flag_, mshadow::kBool) - << "Copying boolean ndarray across devices is not supported"; - CHECK_NE(to->type_flag_, mshadow::kBool) - << "Copying boolean ndarray across devices is not supported"; MSHADOW_TYPE_SWITCH(from.type_flag_, SrcDType, { to->FlatTo1D(s) = mshadow::expr::tcast(from.FlatTo1D(s)); diff --git a/src/operator/contrib/index_copy-inl.h b/src/operator/contrib/index_copy-inl.h index 9f78f0593ed1..35bfcd0e77b6 100644 --- a/src/operator/contrib/index_copy-inl.h +++ b/src/operator/contrib/index_copy-inl.h @@ -71,7 +71,7 @@ inline bool IndexCopyShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->at(0)[i], in_attrs->at(2)[i]); } } - // The the length of the fitrst dim of copied tensor + // The the length of the first dim of copied tensor // must equal to the size of index vector CHECK_EQ(in_attrs->at(1)[0], in_attrs->at(2)[0]); SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); diff --git a/src/operator/contrib/index_copy.cc b/src/operator/contrib/index_copy.cc index f272a8860d85..9a071c04b51c 100644 --- a/src/operator/contrib/index_copy.cc +++ b/src/operator/contrib/index_copy.cc @@ -28,12 +28,12 @@ namespace op { struct index_copy_fwd_cpu { template - static void Map(int i, + static void Map(index_t i, const DType* new_tensor, const IType* idx, DType* out_tensor, int dim_size) { - DType* out_ptr = out_tensor + static_cast(idx[i]) * dim_size; + DType* out_ptr = out_tensor + static_cast(idx[i]) * dim_size; const DType* new_ptr = new_tensor + i * dim_size; std::memcpy(out_ptr, new_ptr, sizeof(DType) * dim_size); } diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h index 2c4127b9a088..d73fa1be54a4 100644 --- a/src/operator/leaky_relu-inl.h +++ b/src/operator/leaky_relu-inl.h @@ -134,7 +134,7 @@ class LeakyReLUOp : public Operator { mshadow::Shape oshape = new_oshape.get(); mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: template LaunchEx(s, new_oshape.Size(), req[leakyrelu::kOut], lstride, rstride, oshape, in_data[leakyrelu::kData].dptr(), in_data[leakyrelu::kGamma].dptr(), diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index a941f9b25a54..471b6f395a05 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -73,6 +73,14 @@ using std::is_integral; } \ } +#define MXNET_UNARY_LOGIC_OP_NC(name, expr) \ + struct name : public mxnet_op::tunable { \ + template \ + MSHADOW_XINLINE static bool Map(DType a) { \ + return (expr); \ + } \ + } + #define MXNET_BINARY_MATH_OP(name, expr) \ struct name : public mxnet_op::tunable { \ template \ @@ -89,6 +97,14 @@ using std::is_integral; } \ } +#define MXNET_BINARY_LOGIC_OP_NC(name, expr) \ + struct name : public mxnet_op::tunable { \ + template \ + MSHADOW_XINLINE static bool Map(DType a, DType b) { \ + return (expr); \ + } \ + } + #define MXNET_SIMPLE_UNARY_MATH_OP(name) MXNET_UNARY_MATH_OP(name, math::name(a)) #define MXNET_SIMPLE_BINARY_MATH_OP(name) MXNET_BINARY_MATH_OP(name, math::name(a, b)) @@ -335,6 +351,8 @@ MXNET_BINARY_MATH_OP(rarctan2_grad, math::id(a) / (math::id(a * a + b * b))); MXNET_UNARY_MATH_OP_NC(nt, a != DType(0) ? DType(0) : DType(1)); +MXNET_UNARY_LOGIC_OP_NC(np_logical_not, !static_cast(a)); + MXNET_BINARY_MATH_OP_NC(ge, a >= b ? DType(1) : DType(0)); MXNET_BINARY_MATH_OP_NC(gt, a > b ? DType(1) : DType(0)); @@ -347,6 +365,18 @@ MXNET_BINARY_MATH_OP_NC(eq, a == b ? DType(1) : DType(0)); MXNET_BINARY_MATH_OP_NC(ne, a != b ? DType(1) : DType(0)); +MXNET_BINARY_LOGIC_OP_NC(np_greater_equal, a >= b ? true : false); + +MXNET_BINARY_LOGIC_OP_NC(np_greater, a > b ? true : false); + +MXNET_BINARY_LOGIC_OP_NC(np_less, a < b ? true : false); + +MXNET_BINARY_LOGIC_OP_NC(np_less_equal, a <= b ? true : false); + +MXNET_BINARY_LOGIC_OP_NC(np_equal, a == b ? true : false); + +MXNET_BINARY_LOGIC_OP_NC(np_not_equal, a != b ? true : false); + MXNET_BINARY_MATH_OP(logical_and, a && b ? DType(1) : DType(0)); MXNET_BINARY_MATH_OP(logical_or, a || b ? DType(1) : DType(0)); diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index b46ce8a598d9..950db174595e 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -703,6 +703,26 @@ struct op_with_req { const DType *input_3) { KERNEL_ASSIGN(out[i], req, OP::Map(input_1[i], input_2[i], input_3[i])); } + + template::value, int>::type = 0> + MSHADOW_XINLINE static void Map(index_t i, bool *out, const DType *in) { + KERNEL_ASSIGN(out[i], req, OP::Map(in[i])); + } + + /*! \brief inputs are two tensors with a boolean output tensor */ + template::value, int>::type = 0> + MSHADOW_XINLINE static void Map(index_t i, bool *out, const DType *lhs, const DType *rhs) { + KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i])); + } + + /*! \brief input is tensor and two scalar value with a boolean output tensor */ + template::value, int>::type = 0> + MSHADOW_XINLINE static void Map(index_t i, bool *out, const DType *in, const DType value) { + KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value)); + } }; template diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index e8ee98c7ac89..9a8dffdfbca4 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -394,7 +394,7 @@ class DropoutOp { mshadow::Shape oshape = new_oshape.get(); mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: template LaunchEx(s, new_oshape.Size(), req[dropout::kOut], lstride, rstride, oshape, @@ -463,7 +463,7 @@ class DropoutOp { mshadow::Shape oshape = new_oshape.get(); mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, grad.dptr(), mask.dptr(), gdata.dptr()); diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 55ba6510f609..7f0eaad7f872 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -189,7 +189,6 @@ bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input) bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output); bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam ¶m); bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data); -bool SupportMKLDNNReshape(const NDArray &in_data, const NDArray &out_data); } // namespace op static int GetTypeSize(int dtype) { diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc index 6a6e3eeeeca5..0130a44c6596 100644 --- a/src/operator/nn/mkldnn/mkldnn_base.cc +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -362,6 +362,7 @@ void FallBackCompute(Compute fn, const AttrState &attrs_states, const std::vector &outputs) { std::vector in_blobs(inputs.size()); std::vector in_bufs; + std::vector new_req = req; for (size_t i = 0; i < in_blobs.size(); i++) { // If the input data isn't stored in the default format, we shouldn't // call data() directly, which will change the layout of the NDArray. @@ -386,6 +387,9 @@ void FallBackCompute(Compute fn, const AttrState &attrs_states, // for inplace, we already converted & copied input above. if ((req[i] == kWriteTo) || (req[i] == kWriteInplace)) { const_cast(output).InvalidateMKLDNNData(); + if (req[i] == kWriteInplace) { + new_req[i] = kWriteTo; + } } else if (req[i] == kAddTo && output.IsMKLDNNData()) { NDArray temp = outputs[i].Reorder2Default(); temp_src.emplace_back(temp); @@ -396,7 +400,7 @@ void FallBackCompute(Compute fn, const AttrState &attrs_states, out_blobs[i] = output.data(); } - fn(attrs_states, ctx, in_blobs, req, out_blobs); + fn(attrs_states, ctx, in_blobs, new_req, out_blobs); for (size_t i = 0; i < out_blobs.size(); i++) { if (req[i] == kAddTo && outputs[i].IsMKLDNNData()) mxnet::common::CastNonDefaultStorage(temp_src, temp_dst, ctx, false); diff --git a/src/operator/nn/mkldnn/mkldnn_expand_dims.cc b/src/operator/nn/mkldnn/mkldnn_expand_dims.cc deleted file mode 100644 index dcd85f1cf60c..000000000000 --- a/src/operator/nn/mkldnn/mkldnn_expand_dims.cc +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file mkldnn_expand_dims.cc - * \brief Implement expand_dims operator via MKL-DNN reorder primitive - * \author Wuxun Zhang -*/ - -#if MXNET_USE_MKLDNN == 100 - -#include "mkldnn_reshape-inl.h" - -namespace mxnet { -namespace op { - -class MKLDNNExpandDimsFwd : public MKLDNNReshapeFwd { - public: - explicit MKLDNNExpandDimsFwd(const OpReqType &req, - const NDArray &input, - const NDArray &output) - : MKLDNNReshapeFwd(req, input, output) {} -}; - -typedef ParamOpSign MKLDNNExpandDimsSignature; - -void MKLDNNExpandDimsForward(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const NDArray &input, - const OpReqType &req, - const NDArray &output) { - const ExpandDimParam& param = nnvm::get(attrs.parsed); - if (req == kNullOp) return; - CHECK_NE(req, kAddTo) << "kAddTo is not supported yet"; - - auto fwd = GetCachedForward(param, req, input, output); - - auto ws_size = fwd.GetWorkspaceSize(); - void* ws_ptr = nullptr; - if (ws_size) { - mshadow::Stream *s = ctx.get_stream(); - mshadow::Tensor ws = ctx.requested[0] - .get_space_typed(mshadow::Shape1(ws_size), s); - ws_ptr = reinterpret_cast(ws.dptr_); - } - - fwd.Execute(input, output, req, ws_ptr); -} - -} // namespace op -} // namespace mxnet - -#endif diff --git a/src/operator/nn/mkldnn/mkldnn_flatten-inl.h b/src/operator/nn/mkldnn/mkldnn_flatten-inl.h deleted file mode 100644 index 89e52cc50988..000000000000 --- a/src/operator/nn/mkldnn/mkldnn_flatten-inl.h +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file mkldnn_flatten-inl.h - * \brief Implement flatten operator by using mkldnn reorder primitive - * \author Wuxun Zhang - */ - -#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FLATTEN_INL_H_ -#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FLATTEN_INL_H_ -#if MXNET_USE_MKLDNN == 100 - -#include "mkldnn_reshape-inl.h" - -namespace mxnet { -namespace op { - -class MKLDNNFlattenFwd : public MKLDNNReshapeFwd { - public: - explicit MKLDNNFlattenFwd(const OpReqType &req, const NDArray &input, const NDArray &output) - : MKLDNNReshapeFwd(req, input, output) {} -}; - -void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const NDArray &input, - const OpReqType &req, const NDArray &output); - -} // namespace op -} // namespace mxnet - -#endif // MXNET_USE_MKLDNN == 1 -#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FLATTEN_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_flatten.cc b/src/operator/nn/mkldnn/mkldnn_flatten.cc deleted file mode 100644 index 4058399ab3fe..000000000000 --- a/src/operator/nn/mkldnn/mkldnn_flatten.cc +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file mkldnn_flatten.cc - * \brief Implement flatten operator via using MKL-DNN reorder primitive - * \author Wuxun Zhang -*/ - -#if MXNET_USE_MKLDNN == 100 - -#include "mkldnn_flatten-inl.h" - -namespace mxnet { -namespace op { - -static MKLDNNFlattenFwd &GetFlattenForward(const OpReqType &req, - const NDArray &input, - const NDArray &output) { -#if DMLC_CXX11_THREAD_LOCAL - static thread_local std::unordered_map fwds; -#else - static MX_THREAD_LOCAL std::unordered_map fwds; -#endif - OpSignature key; - key.AddSign(req); - key.AddSign(input); - - auto it = fwds.find(key); - if (it == fwds.end()) { - MKLDNNFlattenFwd fwd(req, input, output); - it = AddToCache(&fwds, key, fwd); - } - return it->second; -} - -void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const NDArray &input, - const OpReqType &req, - const NDArray &output) { - if (req == kNullOp) return; - CHECK_NE(req, kAddTo) << "kAddTo is not supported yet"; - - auto fwd = GetFlattenForward(req, input, output); - auto ws_size = fwd.GetWorkspaceSize(); - void* ws_ptr = nullptr; - if (ws_size) { - mshadow::Stream *s = ctx.get_stream(); - mshadow::Tensor ws = ctx.requested[0] - .get_space_typed(mshadow::Shape1(ws_size), s); - ws_ptr = reinterpret_cast(ws.dptr_); - } - - fwd.Execute(input, output, req, ws_ptr); -} - -} // namespace op -} // namespace mxnet - -#endif diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h index bba76a3cc570..23f059b32240 100644 --- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h @@ -38,26 +38,10 @@ #if MXNET_USE_MKLDNN == 100 #include -#endif namespace mxnet { namespace op { -#if MXNET_USE_MKLDNN == 1 -void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const NDArray &input, - const OpReqType &req, - const NDArray &output); - -void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const NDArray &input, - const OpReqType &req, - const NDArray &output); -#endif - -#if MXNET_USE_MKLDNN == 100 /* For fully connected. */ void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &in_data, @@ -148,19 +132,8 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs, const NDArray &input, const OpReqType &req, const NDArray &output); -void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const NDArray &input, - const OpReqType &req, - const NDArray &output); -void MKLDNNExpandDimsForward(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const NDArray &input, - const OpReqType &req, - const NDArray &output); -#endif - } // namespace op } // namespace mxnet +#endif // MXNET_USE_MKLDNN == 100 #endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_OPS_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_reshape-inl.h b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h index aa0f11ca7afb..8c6d38e5ab31 100644 --- a/src/operator/nn/mkldnn/mkldnn_reshape-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h @@ -1,91 +1,61 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file mkldnn_reshape-inl.h - * \brief Function definition of mkldnn reshape operator - */ - -#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_ -#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_ - -#if MXNET_USE_MKLDNN == 100 -#include -#include "mkldnn_base-inl.h" -#include "../../tensor/matrix_op-inl.h" - -namespace mxnet { -namespace op { - -class MKLDNNReshapeFwd { - protected: - std::shared_ptr out_; - std::shared_ptr temp_; - std::vector prims_; - bool needInvalidateInput = false; - - public: - MKLDNNReshapeFwd(const OpReqType &req, - const NDArray &input, - const NDArray &output); - int GetWorkspaceSize(); - void Execute(const NDArray &input, - const NDArray &output, - const OpReqType &req, - void* workspace = nullptr); -}; - -typedef ParamOpSign MKLDNNReshapeSignature; - -template -MKLDNNOpFwdType &GetCachedForward(const ParamType& param, - const OpReqType &req, - const NDArray &input, - const NDArray &output) { -#if DMLC_CXX11_THREAD_LOCAL - static thread_local std::unordered_map fwds; -#else - static MX_THREAD_LOCAL std::unordered_map fwds; -#endif - MKLDNNSigatureType key(param); - key.AddSign(req); - key.AddSign(input); - key.AddSign(output); - - auto it = fwds.find(key); - if (it == fwds.end()) { - MKLDNNOpFwdType fwd(req, input, output); - it = AddToCache(&fwds, key, fwd); - } - return it->second; -} - -MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param, - const OpReqType &req, - const NDArray &input, - const NDArray &output); - -} // namespace op -} // namespace mxnet - -#endif // MXNET_USE_MKLDNN == 1 -#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file mkldnn_reshape-inl.h + * \brief Function definition of mkldnn reshape operator + */ + +#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_ +#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_ + +#if MXNET_USE_MKLDNN == 100 +#include +#include "mkldnn_base-inl.h" +#include "../../tensor/matrix_op-inl.h" + +namespace mxnet { +namespace op { + +class MKLDNNReshapeFwd { + protected: + std::shared_ptr out_; + std::shared_ptr temp_; + std::vector prims_; + + public: + MKLDNNReshapeFwd(const OpReqType &req, + const NDArray &input, + const NDArray &output); + int GetWorkspaceSize(); + void Execute(const NDArray &input, + const NDArray &output, + const OpReqType &req, + void* workspace = nullptr); +}; + +typedef OpSignature MKLDNNReshapeSignature; +MKLDNNReshapeFwd &GetReshapeForward(const OpReqType &req, const NDArray &input, + const NDArray &output); +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 100 +#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_reshape.cc b/src/operator/nn/mkldnn/mkldnn_reshape.cc index d180125b16bb..1c1e72f1aaf9 100644 --- a/src/operator/nn/mkldnn/mkldnn_reshape.cc +++ b/src/operator/nn/mkldnn/mkldnn_reshape.cc @@ -24,61 +24,38 @@ */ #if MXNET_USE_MKLDNN == 100 - -#include -#include "mkldnn_reshape-inl.h" +#include "../../tensor/elemwise_unary_op.h" +#include "./mkldnn_ops-inl.h" +#include "./mkldnn_base-inl.h" +#include "./mkldnn_reshape-inl.h" namespace mxnet { namespace op { -bool SupportMKLDNNReshape(const NDArray &in_data, - const NDArray &out_data) { - auto in_ndim = in_data.shape().ndim(); - auto out_ndim = out_data.shape().ndim(); - - if (in_ndim > 4 || - in_data.dtype() != mshadow::kFloat32 || - out_ndim > 4) - return false; - - return true; -} - -MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType &req, - const NDArray &input, +MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType &req, const NDArray &input, const NDArray &output) { - auto engine = CpuEngine::Get()->get_engine(); - - // source + const auto engine = CpuEngine::Get()->get_engine(); auto in_mem = input.GetMKLDNNData(); - auto in_md = in_mem->get_desc(); - - // temp_ - auto temp_md = GetDesc(in_md, GetDefaultFormat(in_md)); - temp_ = std::make_shared(temp_md, engine, nullptr); - // destination - out_ = std::make_shared(temp_md, engine, nullptr); + // Create temp memory + auto temp_dims = mkldnn::memory::dims(input.shape().begin(), input.shape().end()); + auto temp_type = static_cast(get_mkldnn_type(input.dtype())); + auto temp_fmt = static_cast(GetDefaultFormat(input.shape().ndim())); + auto temp_desc = mkldnn::memory::desc(temp_dims, temp_type, temp_fmt); + out_ = std::make_shared(temp_desc, engine, nullptr); if (req == kWriteInplace) { // If the input has MKL-DNN internal layout, we need reorder it to a temporal buffer with // default layout and copy from the temporal buffer back to output buffer which has the same // address with input buffer. // If the input has default layout, then nothing need to do. if (input.IsMKLDNNData()) { + temp_ = std::make_shared(temp_desc, engine, nullptr); prims_.push_back(mkldnn::reorder(*in_mem, *temp_)); // reorder to default prims_.push_back(mkldnn::reorder(*temp_, *out_)); // copy back - needInvalidateInput = true; } } else if (req == kWriteTo) { - if (input.IsMKLDNNData()) { - prims_.push_back(mkldnn::reorder(*in_mem, *temp_)); // reorder to default - prims_.push_back(mkldnn::reorder(*temp_, *out_)); // copy to the output buffer - needInvalidateInput = false; - } else { - prims_.push_back(mkldnn::reorder(*in_mem, *out_)); // copy directly from input to output - needInvalidateInput = false; - } + prims_.push_back(mkldnn::reorder(*in_mem, *out_)); } else { LOG(FATAL) << "not supported req type: " << req; } @@ -117,10 +94,30 @@ void MKLDNNReshapeFwd::Execute(const NDArray &input, stream->RegisterPrimArgs(prims_[i], args_map[i]); } stream->Submit(); - // invalidate mkldnn memory in input - if (needInvalidateInput) { - const_cast(input).InvalidateMKLDNNData(); + // invalidate mkldnn memory in output + const_cast(output).InvalidateMKLDNNData(); +} + +MKLDNNReshapeFwd &GetReshapeForward(const OpReqType &req, + const NDArray &input, + const NDArray &output) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map fwds; +#else + static MX_THREAD_LOCAL std::unordered_map fwds; +#endif + MKLDNNReshapeSignature key; + key.AddSign(req); + key.AddSign(input); + + auto it = fwds.find(key); + if (it == fwds.end()) { + MKLDNNReshapeFwd fwd(req, input, output); + it = AddToCache(&fwds, key, fwd); } + return it->second; } void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs, @@ -128,24 +125,28 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs, const NDArray &input, const OpReqType &req, const NDArray &output) { - const ReshapeParam& param = nnvm::get(attrs.parsed); + // For mkldnn non-supported input, it shouldn't hold mkldnn memory, so let's simply fallback to + // naive implement. + if (input.shape().ndim() > 4 || !SupportMKLDNNQuantize(input.dtype())) { + if (req != kWriteInplace) { + FallBackCompute(UnaryOp::IdentityCompute, attrs, ctx, {input}, {req}, {output}); + } + return; + } if (req == kNullOp) return; CHECK_NE(req, kAddTo) << "kAddTo is not supported yet"; - - auto fwd = GetCachedForward(param, req, input, output); - + auto fwd = GetReshapeForward(req, input, output); auto ws_size = fwd.GetWorkspaceSize(); void* ws_ptr = nullptr; if (ws_size) { mshadow::Stream *s = ctx.get_stream(); mshadow::Tensor ws = ctx.requested[0] .get_space_typed(mshadow::Shape1(ws_size), s); - ws_ptr = reinterpret_cast(ws.dptr_); + ws_ptr = static_cast(ws.dptr_); } - fwd.Execute(input, output, req, ws_ptr); } + } // namespace op } // namespace mxnet #endif diff --git a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc new file mode 100644 index 000000000000..7e8951afa1d0 --- /dev/null +++ b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_elemwise_binary_logic_op.cc + * \brief CPU Implementation of basic logic functions for elementwise numpy binary + * broadcast operator. + */ + +#if MXNET_USE_TVM_OP +#include +#include +#include "../tvmop/op_module.h" +#endif // MXNET_USE_TVM_OP + +#include "../tensor/elemwise_binary_broadcast_op.h" +#include "../tensor/elemwise_binary_scalar_op.h" + +namespace mxnet { +namespace op { + +static constexpr char func_equal_cpu[] = "equal_cpu"; +static constexpr char func_equal_gpu[] = "equal_gpu"; +static constexpr char func_not_equal_cpu[] = "not_equal_cpu"; +static constexpr char func_not_equal_gpu[] = "not_equal_gpu"; +static constexpr char func_greater_cpu[] = "greater_cpu"; +static constexpr char func_greater_gpu[] = "greater_gpu"; +static constexpr char func_less_cpu[] = "less_cpu"; +static constexpr char func_less_gpu[] = "less_gpu"; +static constexpr char func_greater_equal_cpu[] = "greater_equal_cpu"; +static constexpr char func_greater_equal_gpu[] = "greater_equal_gpu"; +static constexpr char func_less_equal_cpu[] = "less_equal_cpu"; +static constexpr char func_less_equal_gpu[] = "less_equal_gpu"; + +bool NumpyBinaryLogicOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + if (in_attrs->at(0) == -1 && in_attrs->at(1) == -1) return false; + TYPE_ASSIGN_CHECK(*in_attrs, 0, in_attrs->at(1)); + TYPE_ASSIGN_CHECK(*in_attrs, 1, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool); + return true; +} + +TBlob PrependAxes(const TBlob& src, const int dst_ndim) { + CHECK_LE(src.shape_.ndim(), dst_ndim); + const int src_ndim = src.shape_.ndim(); + if (src_ndim == dst_ndim) return src; + mxnet::TShape dst_shape(dst_ndim, 1); + for (int i = dst_ndim - src_ndim; i < dst_ndim; ++i) { + dst_shape[i] = src.shape_[i - dst_ndim + src_ndim]; + } + return src.reshape(dst_shape); +} + +struct TVMBinaryBroadcastCompute { + const char* func; + void operator()(const nnvm::NodeAttrs& attrs, + const mxnet::OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { +#if MXNET_USE_TVM_OP + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + if (outputs[0].shape_.Size() == 0U) return; // skip zero-size tensor + + // prepare tblobs and TVMArgs + std::vector tblobs = {inputs[0], inputs[1], outputs[0]}; + std::vector type_codes; + std::vector values; + + const int ondim = outputs[0].shape_.ndim(); + const size_t num_args = inputs.size() + outputs.size(); + type_codes.resize(num_args); + values.resize(num_args); + for (size_t i = 0; i < num_args; ++i) { + tblobs[i] = PrependAxes(tblobs[i], ondim); + type_codes[i] = kArrayHandle; + values[i].v_handle = const_cast(&(tblobs[i].dltensor())); + } + tvm::runtime::TVMArgs tvm_args(&values[0], &type_codes[0], tblobs.size()); + tvm::runtime::TVMOpModule::Get()->CallEx(func, ctx, tblobs, tvm_args); +#else + LOG(FATAL) << "Please add USE_TVM_OP=1 as a compile flag for compiling MXNet source code " + "to enable TVM-generated kernels for operator " << func; +#endif // MXNET_USE_TVM_OP + } +}; + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(name) \ + NNVM_REGISTER_OP(_npi_##name) \ + .set_num_inputs(2) \ + .set_num_outputs(1) \ + .set_attr("FListInputNames", \ + [](const NodeAttrs& attrs) { \ + return std::vector{"lhs", "rhs"}; \ + }) \ + .set_attr("FInferShape", BinaryBroadcastShape) \ + .set_attr("FInferType", NumpyBinaryLogicOpType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs) { \ + return std::vector >{{0, 0}, {1, 0}}; \ + }) \ + .set_attr("FGradient", MakeZeroGradNodes) \ + .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \ + .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") + +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(less_equal); + +#if MXNET_USE_TVM_OP + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_CPU(name) \ + NNVM_REGISTER_OP(_npi_##name) \ + .set_attr("FCompute", TVMBinaryBroadcastCompute{func_##name##_cpu}) + +#if MXNET_USE_CUDA + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(name) \ + NNVM_REGISTER_OP(_npi_##name) \ + .set_attr("FCompute", TVMBinaryBroadcastCompute{func_##name##_gpu}) + +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(less_equal); + +#endif // MXNET_USE_CUDA + +#else + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_CPU(name) \ + NNVM_REGISTER_OP(_npi_##name) \ + .set_attr("FCompute", BinaryBroadcastComputeLogic) + +#endif // MXNET_USE_TVM_OP + +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_CPU(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_CPU(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_CPU(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_CPU(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_CPU(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_CPU(less_equal); + +bool NumpyBinaryScalarLogicOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + if (in_attrs->at(0) == -1) return false; + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool); + return true; +} + +struct TVMBinaryBroadcastScalarCompute { + const char* func; + void operator()(const nnvm::NodeAttrs& attrs, + const mxnet::OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { +#if MXNET_USE_TVM_OP + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + if (outputs[0].shape_.Size() == 0U) return; // skip zero-size tensor + + // prepare tblobs and TVMArgs + std::vector tblobs = {inputs[0], outputs[0]}; + std::vector type_codes; + std::vector values; + + const size_t num_args = 3; // one input tensor, one scalar param, and one output + type_codes.resize(num_args); + values.resize(num_args); + + // input tensor setup + type_codes[0] = kArrayHandle; + values[0].v_handle = const_cast(&(tblobs[0].dltensor())); + + // scalar param + type_codes[1] = kDLFloat; + values[1].v_float64 = nnvm::get(attrs.parsed); + + // output tensor + type_codes[2] = kArrayHandle; + values[2].v_handle = const_cast(&(tblobs[1].dltensor())); + + tvm::runtime::TVMArgs tvm_args(&values[0], &type_codes[0], 3); + tvm::runtime::TVMOpModule::Get()->CallEx(func, ctx, tblobs, tvm_args); +#else + LOG(FATAL) << "Please add USE_TVM_OP=1 as a compile flag for compiling MXNet source code " + "to enable TVM-generated kernels for operator " << func; +#endif // MXNET_USE_TVM_OP + } +}; + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(name) \ + NNVM_REGISTER_OP(_npi_##name##_scalar) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr_parser([](NodeAttrs* attrs) { \ + attrs->parsed = std::stod(attrs->dict["scalar"]); \ + }) \ + .set_attr("FListInputNames", \ + [](const NodeAttrs& attrs) { \ + return std::vector{"data"}; \ + }) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyBinaryScalarLogicOpType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs) { \ + return std::vector >{{0, 0}}; \ + }) \ + .set_attr("FGradient", MakeZeroGradNodes) \ + .add_argument("data", "NDArray-or-Symbol", "First input to the function") \ + .add_argument("scalar", "float", "scalar input") + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(less_equal); + +static constexpr char func_equal_scalar_cpu[] = "equal_scalar_cpu"; +static constexpr char func_equal_scalar_gpu[] = "equal_scalar_gpu"; +static constexpr char func_not_equal_scalar_cpu[] = "not_equal_scalar_cpu"; +static constexpr char func_not_equal_scalar_gpu[] = "not_equal_scalar_gpu"; +static constexpr char func_greater_scalar_cpu[] = "greater_scalar_cpu"; +static constexpr char func_greater_scalar_gpu[] = "greater_scalar_gpu"; +static constexpr char func_less_scalar_cpu[] = "less_scalar_cpu"; +static constexpr char func_less_scalar_gpu[] = "less_scalar_gpu"; +static constexpr char func_greater_equal_scalar_cpu[] = "greater_equal_scalar_cpu"; +static constexpr char func_greater_equal_scalar_gpu[] = "greater_equal_scalar_gpu"; +static constexpr char func_less_equal_scalar_cpu[] = "less_equal_scalar_cpu"; +static constexpr char func_less_equal_scalar_gpu[] = "less_equal_scalar_gpu"; + +#if MXNET_USE_TVM_OP + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(name) \ + NNVM_REGISTER_OP(_npi_##name##_scalar) \ + .set_attr("FCompute", TVMBinaryBroadcastScalarCompute{func_##name##_scalar_cpu}) + +#if MXNET_USE_CUDA + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(name) \ + NNVM_REGISTER_OP(_npi_##name##_scalar) \ + .set_attr("FCompute", TVMBinaryBroadcastScalarCompute{func_##name##_scalar_gpu}) + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(less_equal); + +#endif // MXNET_USE_CUDA + +#else + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(name) \ + NNVM_REGISTER_OP(_npi_##name##_scalar) \ + .set_attr("FCompute", BinaryScalarOp::ComputeLogic) + +#endif // MXNET_USE_TVM_OP + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(less_equal); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_logic_op.cu b/src/operator/numpy/np_elemwise_broadcast_logic_op.cu new file mode 100644 index 000000000000..98995a39dcb9 --- /dev/null +++ b/src/operator/numpy/np_elemwise_broadcast_logic_op.cu @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_elemwise_broadcast_logic_op.cu + * \brief GPU Implementation of basic functions for elementwise binary + * broadcast logic operator. + */ +#include "../tensor/elemwise_binary_broadcast_op.h" +#include "../tensor/elemwise_binary_scalar_op.h" + +namespace mxnet { +namespace op { + + +#if MXNET_USE_TVM_OP == 0 + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(name) \ + NNVM_REGISTER_OP(_npi_##name) \ + .set_attr("FCompute", BinaryBroadcastComputeLogic) + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(name) \ + NNVM_REGISTER_OP(_npi_##name##_scalar) \ + .set_attr("FCompute", BinaryScalarOp::ComputeLogic) + +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(less_equal); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(less_equal); + +#endif // MXNET_USE_TVM_OP + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index aa81a58c2890..7e07e47b4a8f 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -23,12 +23,6 @@ * \brief CPU Implementation of basic functions for elementwise numpy binary broadcast operator. */ -#if MXNET_USE_TVM_OP -#include -#include -#include "../tvmop/op_module.h" -#endif // MXNET_USE_TVM_OP - #include "../tensor/elemwise_binary_broadcast_op.h" #include "../tensor/elemwise_binary_scalar_op.h" @@ -302,223 +296,6 @@ NNVM_REGISTER_OP(_backward_npi_hypot) .set_attr("FCompute", BinaryBroadcastBackwardUseIn); -static constexpr char func_equal_cpu[] = "equal_cpu"; -static constexpr char func_equal_gpu[] = "equal_gpu"; -static constexpr char func_not_equal_cpu[] = "not_equal_cpu"; -static constexpr char func_not_equal_gpu[] = "not_equal_gpu"; -static constexpr char func_greater_cpu[] = "greater_cpu"; -static constexpr char func_greater_gpu[] = "greater_gpu"; -static constexpr char func_less_cpu[] = "less_cpu"; -static constexpr char func_less_gpu[] = "less_gpu"; -static constexpr char func_greater_equal_cpu[] = "greater_equal_cpu"; -static constexpr char func_greater_equal_gpu[] = "greater_equal_gpu"; -static constexpr char func_less_equal_cpu[] = "less_equal_cpu"; -static constexpr char func_less_equal_gpu[] = "less_equal_gpu"; - -bool NumpyBinaryLogicOpType(const nnvm::NodeAttrs& attrs, - std::vector* in_attrs, - std::vector* out_attrs) { - CHECK_EQ(in_attrs->size(), 2U); - CHECK_EQ(out_attrs->size(), 1U); - if (in_attrs->at(0) == -1 && in_attrs->at(1) == -1) return false; - TYPE_ASSIGN_CHECK(*in_attrs, 0, in_attrs->at(1)); - TYPE_ASSIGN_CHECK(*in_attrs, 1, in_attrs->at(0)); - TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool); - return true; -} - -TBlob PrependAxes(const TBlob& src, const int dst_ndim) { - CHECK_LE(src.shape_.ndim(), dst_ndim); - const int src_ndim = src.shape_.ndim(); - if (src_ndim == dst_ndim) return src; - mxnet::TShape dst_shape(dst_ndim, 1); - for (int i = dst_ndim - src_ndim; i < dst_ndim; ++i) { - dst_shape[i] = src.shape_[i - dst_ndim + src_ndim]; - } - return src.reshape(dst_shape); -} - -struct TVMBinaryBroadcastCompute { - const char* func; - void operator()(const nnvm::NodeAttrs& attrs, - const mxnet::OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { -#if MXNET_USE_TVM_OP - CHECK_EQ(inputs.size(), 2U); - CHECK_EQ(outputs.size(), 1U); - if (outputs[0].shape_.Size() == 0U) return; // skip zero-size tensor - - // prepare tblobs and TVMArgs - std::vector tblobs = {inputs[0], inputs[1], outputs[0]}; - std::vector type_codes; - std::vector values; - - const int ondim = outputs[0].shape_.ndim(); - const size_t num_args = inputs.size() + outputs.size(); - type_codes.resize(num_args); - values.resize(num_args); - for (size_t i = 0; i < num_args; ++i) { - tblobs[i] = PrependAxes(tblobs[i], ondim); - type_codes[i] = kArrayHandle; - values[i].v_handle = const_cast(&(tblobs[i].dltensor())); - } - tvm::runtime::TVMArgs tvm_args(&values[0], &type_codes[0], tblobs.size()); - tvm::runtime::TVMOpModule::Get()->CallEx(func, ctx, tblobs, tvm_args); -#else - LOG(FATAL) << "Please add USE_TVM_OP=1 as a compile flag for compiling MXNet source code " - "to enable TVM-generated kernels for operator " << func; -#endif // MXNET_USE_TVM_OP - } -}; - -#define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(name) \ - NNVM_REGISTER_OP(_npi_##name) \ - .set_num_inputs(2) \ - .set_num_outputs(1) \ - .set_attr("FListInputNames", \ - [](const NodeAttrs& attrs) { \ - return std::vector{"lhs", "rhs"}; \ - }) \ - .set_attr("FInferShape", BinaryBroadcastShape) \ - .set_attr("FInferType", NumpyBinaryLogicOpType) \ - .set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs) { \ - return std::vector >{{0, 0}, {1, 0}}; \ - }) \ - .set_attr("FCompute", TVMBinaryBroadcastCompute{func_##name##_cpu}) \ - .set_attr("FGradient", MakeZeroGradNodes) \ - .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \ - .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") - -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(equal); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(not_equal); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(greater); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(less); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(greater_equal); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(less_equal); - -#if MXNET_USE_CUDA -#define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(name) \ - NNVM_REGISTER_OP(_npi_##name) \ - .set_attr("FCompute", TVMBinaryBroadcastCompute{func_##name##_gpu}) - -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(equal); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(not_equal); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(greater); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(less); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(greater_equal); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(less_equal); -#endif // MXNET_USE_CUDA - -bool NumpyBinaryScalarLogicOpType(const nnvm::NodeAttrs& attrs, - std::vector* in_attrs, - std::vector* out_attrs) { - CHECK_EQ(in_attrs->size(), 1U); - CHECK_EQ(out_attrs->size(), 1U); - if (in_attrs->at(0) == -1) return false; - TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool); - return true; -} - -struct TVMBinaryBroadcastScalarCompute { - const char* func; - void operator()(const nnvm::NodeAttrs& attrs, - const mxnet::OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { -#if MXNET_USE_TVM_OP - CHECK_EQ(inputs.size(), 1U); - CHECK_EQ(outputs.size(), 1U); - if (outputs[0].shape_.Size() == 0U) return; // skip zero-size tensor - - // prepare tblobs and TVMArgs - std::vector tblobs = {inputs[0], outputs[0]}; - std::vector type_codes; - std::vector values; - - const size_t num_args = 3; // one input tensor, one scalar param, and one output - type_codes.resize(num_args); - values.resize(num_args); - - // input tensor setup - type_codes[0] = kArrayHandle; - values[0].v_handle = const_cast(&(tblobs[0].dltensor())); - - // scalar param - type_codes[1] = kDLFloat; - values[1].v_float64 = nnvm::get(attrs.parsed); - - // output tensor - type_codes[2] = kArrayHandle; - values[2].v_handle = const_cast(&(tblobs[1].dltensor())); - - tvm::runtime::TVMArgs tvm_args(&values[0], &type_codes[0], 3); - tvm::runtime::TVMOpModule::Get()->CallEx(func, ctx, tblobs, tvm_args); -#else - LOG(FATAL) << "Please add USE_TVM_OP=1 as a compile flag for compiling MXNet source code " - "to enable TVM-generated kernels for operator " << func; -#endif // MXNET_USE_TVM_OP - } -}; - -#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(name) \ - NNVM_REGISTER_OP(_npi_##name) \ - .set_num_inputs(1) \ - .set_num_outputs(1) \ - .set_attr_parser([](NodeAttrs* attrs) { \ - attrs->parsed = std::stod(attrs->dict["scalar"]); \ - }) \ - .set_attr("FListInputNames", \ - [](const NodeAttrs& attrs) { \ - return std::vector{"data"}; \ - }) \ - .set_attr("FInferShape", ElemwiseShape<1, 1>) \ - .set_attr("FInferType", NumpyBinaryScalarLogicOpType) \ - .set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs) { \ - return std::vector >{{0, 0}}; \ - }) \ - .set_attr("FCompute", TVMBinaryBroadcastScalarCompute{func_##name##_cpu}) \ - .set_attr("FGradient", MakeZeroGradNodes) \ - .add_argument("data", "NDArray-or-Symbol", "First input to the function") \ - .add_argument("scalar", "float", "scalar input") - -static constexpr char func_equal_scalar_cpu[] = "equal_scalar_cpu"; -static constexpr char func_equal_scalar_gpu[] = "equal_scalar_gpu"; -static constexpr char func_not_equal_scalar_cpu[] = "not_equal_scalar_cpu"; -static constexpr char func_not_equal_scalar_gpu[] = "not_equal_scalar_gpu"; -static constexpr char func_greater_scalar_cpu[] = "greater_scalar_cpu"; -static constexpr char func_greater_scalar_gpu[] = "greater_scalar_gpu"; -static constexpr char func_less_scalar_cpu[] = "less_scalar_cpu"; -static constexpr char func_less_scalar_gpu[] = "less_scalar_gpu"; -static constexpr char func_greater_equal_scalar_cpu[] = "greater_equal_scalar_cpu"; -static constexpr char func_greater_equal_scalar_gpu[] = "greater_equal_scalar_gpu"; -static constexpr char func_less_equal_scalar_cpu[] = "less_equal_scalar_cpu"; -static constexpr char func_less_equal_scalar_gpu[] = "less_equal_scalar_gpu"; - -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(equal_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(not_equal_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(greater_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(less_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(greater_equal_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(less_equal_scalar); - -#if MXNET_USE_CUDA -#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(name) \ - NNVM_REGISTER_OP(_npi_##name) \ - .set_attr("FCompute", TVMBinaryBroadcastScalarCompute{func_##name##_gpu}) - -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(equal_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(not_equal_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(greater_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(less_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(greater_equal_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(less_equal_scalar); -#endif // MXNET_USE_CUDA - MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_ldexp) .set_attr("FCompute", BinaryBroadcastCompute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_ldexp"}); diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cc b/src/operator/numpy/np_elemwise_unary_op_basic.cc index b454d89212e2..c980dcfaab5d 100644 --- a/src/operator/numpy/np_elemwise_unary_op_basic.cc +++ b/src/operator/numpy/np_elemwise_unary_op_basic.cc @@ -65,22 +65,49 @@ NNVM_REGISTER_OP(_np_copy) }) .add_argument("a", "NDArray-or-Symbol", "The input"); -#define MXNET_OPERATOR_REGISTER_NUMPY_UNARY(__name$, __input_name$, __kernel$) \ -NNVM_REGISTER_OP(__name$) \ -.set_num_inputs(1) \ -.set_num_outputs(1) \ -.set_attr("FInferShape", ElemwiseShape<1, 1>) \ -.set_attr("FInferType", ElemwiseType<1, 1>) \ -.set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs){ \ - return std::vector >{{0, 0}}; \ - }) \ -.set_attr("FListInputNames", \ - [](const NodeAttrs& attrs) { \ - return std::vector{__input_name$}; \ - }) \ -.set_attr("FCompute", UnaryOp::Compute) \ -.add_argument(__input_name$, "NDArray-or-Symbol", "The input array.") +#define MXNET_OPERATOR_REGISTER_NUMPY_UNARY(__name$, __input_name$, __kernel$) \ + NNVM_REGISTER_OP(__name$) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", ElemwiseType<1, 1>) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs){ \ + return std::vector >{{0, 0}}; \ + }) \ + .set_attr("FListInputNames", \ + [](const NodeAttrs& attrs) { \ + return std::vector{__input_name$}; \ + }) \ + .set_attr("FCompute", UnaryOp::Compute) \ + .add_argument(__input_name$, "NDArray-or-Symbol", "The input array.") + +bool NumpyUnaryLogicOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + if (in_attrs->at(0) == -1) return false; + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool); + return true; +} + +#define MXNET_OPERATOR_REGISTER_NUMPY_UNARY_LOGIC(__name$, __input_name$, __kernel$) \ + NNVM_REGISTER_OP(__name$) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyUnaryLogicOpType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs){ \ + return std::vector >{{0, 0}}; \ + }) \ + .set_attr("FListInputNames", \ + [](const NodeAttrs& attrs) { \ + return std::vector{__input_name$}; \ + }) \ + .set_attr("FCompute", UnaryOp::ComputeLogic) \ + .add_argument(__input_name$, "NDArray-or-Symbol", "The input array.") // negative MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_negative, "x", mshadow_op::negation) @@ -243,11 +270,7 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_expm1, "x", mshadow_op::expm1) // logical_not -MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_logical_not, "x", mshadow_op::nt) -.describe(R"code(Compute the truth value of NOT x element-wise. -Example:: - logical_not([-2., 0., 1.]) = [0., 1., 0.] -)code") +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_LOGIC(_npi_logical_not, "x", mshadow_op::np_logical_not) .set_attr("FGradient", MakeZeroGradNodes); // sin diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cu b/src/operator/numpy/np_elemwise_unary_op_basic.cu index 7f68386560f3..44743ed94be8 100644 --- a/src/operator/numpy/np_elemwise_unary_op_basic.cu +++ b/src/operator/numpy/np_elemwise_unary_op_basic.cu @@ -35,9 +35,9 @@ NNVM_REGISTER_OP(_npx_sigmoid) NNVM_REGISTER_OP(_np_copy) .set_attr("FCompute", UnaryOp::IdentityCompute); -#define MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(__name$, __kernel$) \ -NNVM_REGISTER_OP(__name$) \ -.set_attr("FCompute", UnaryOp::Compute) \ +#define MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(__name$, __kernel$) \ + NNVM_REGISTER_OP(__name$) \ + .set_attr("FCompute", UnaryOp::Compute) MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_negative, mshadow_op::negation); @@ -76,7 +76,8 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_log1p, mshadow_op::log1p); MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_expm1, mshadow_op::expm1); -MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_logical_not, mshadow_op::nt); +NNVM_REGISTER_OP(_npi_logical_not) +.set_attr("FCompute", UnaryOp::ComputeLogic); MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_sin, mshadow_op::sin); diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index 5e25192d9298..0eefeb9cdb5d 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -29,6 +29,7 @@ #include #include "../tensor/matrix_op-inl.h" #include "../nn/concat-inl.h" +#include "../../common/utils.h" namespace mxnet { namespace op { @@ -59,7 +60,8 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs, const NumpyTransposeParam& param = nnvm::get(attrs.parsed); CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace"; if (ndim_is_known(param.axes)) { - TransposeImpl(ctx.run_ctx, inputs[0], outputs[0], param.axes); + mxnet::TShape axes = common::CanonicalizeAxes(param.axes); + TransposeImpl(ctx.run_ctx, inputs[0], outputs[0], axes); } else { mxnet::TShape axes(inputs[0].ndim(), -1); for (int i = 0; i < axes.ndim(); ++i) { diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 96a10561be28..83f0b1aae9b0 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -40,21 +40,53 @@ bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape& shp = (*in_attrs)[0]; + mxnet::TShape& out_shp = (*out_attrs)[0]; CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions"; - mxnet::TShape ret(shp.ndim(), -1); + + int ndim = -1; + if (ndim_is_known(shp)) { + ndim = shp.ndim(); + } else if (ndim_is_known(out_shp)) { + ndim = out_shp.ndim(); + } + if (ndim < 0) { + return false; + } + if (out_shp.ndim() >= 0 && shp.ndim() >= 0) { + CHECK_EQ(out_shp.ndim(), shp.ndim()); + } + + mxnet::TShape get(ndim, -1); + mxnet::TShape ret(ndim, -1); + if (ndim_is_known(param.axes)) { - CHECK_EQ(shp.ndim(), param.axes.ndim()); - for (int i = 0; i < shp.ndim(); ++i) { - CHECK(param.axes[i] < static_cast(shp.ndim())); - ret[i] = shp[param.axes[i]]; + CHECK_EQ(ndim, param.axes.ndim()); + mxnet::TShape axes = common::CanonicalizeAxes(param.axes); + if (ndim_is_known(shp)) { + for (int i = 0; i < ndim; ++i) { + ret[i] = shp[axes[i]]; + } + } + if (ndim_is_known(out_shp)) { + for (int i = 0; i < ndim; ++i) { + get[axes[i]] = out_shp[i]; + } } } else { - for (int i = 0; i < shp.ndim(); ++i) { - ret[i] = shp[shp.ndim()-1-i]; + if (ndim_is_known(shp)) { + for (int i = 0; i < ndim; ++i) { + ret[i] = shp[ndim - 1 - i]; + } + } + if (ndim_is_known(out_shp)) { + for (int i = 0; i < ndim; ++i) { + get[ndim - 1 - i] = out_shp[i]; + } } } + SHAPE_ASSIGN_CHECK(*in_attrs, 0, get); SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret); - return shape_is_known(ret); + return shape_is_known(*in_attrs) && shape_is_known(*out_attrs); } NNVM_REGISTER_OP(_np_transpose) @@ -69,7 +101,12 @@ NNVM_REGISTER_OP(_np_transpose) if (ndim_is_known(param.axes)) { mxnet::TShape axes = mxnet::TShape(param.axes.ndim(), -1); for (int i = 0; i < axes.ndim(); ++i) { - axes[param.axes[i]] = i; + int axis = param.axes[i]; + if (axis < 0) { + axis += param.axes.ndim(); + } + CHECK(axis >= 0 && axis < param.axes.ndim()); + axes[axis] = i; } std::ostringstream os; os << axes; @@ -248,7 +285,7 @@ NNVM_REGISTER_OP(_np_squeeze) .set_attr("FInferType", ElemwiseType<1, 1>) .set_attr("FCompute", UnaryOp::IdentityCompute) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_squeeze"}) -.add_argument("a", "NDArray-or-Symbol[]", "data to squeeze") +.add_argument("a", "NDArray-or-Symbol", "data to squeeze") .add_arguments(SqueezeParam::__FIELDS__()); bool ConcatShape(const nnvm::NodeAttrs& attrs, diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index 49d0e23bdd19..398d5a714bd1 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -209,6 +209,9 @@ struct static_init_var { #define IMPLEMENT_BINARY_WORKLOAD_FWD(__op$) \ MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_BINARY_WORKLOAD_FWD, __op$) +#define IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(__op$) \ + MSHADOW_MACRO_FOREACH_TYPE_WITH_BOOL(_IMPLEMENT_BINARY_WORKLOAD_FWD, __op$) + #define IMPLEMENT_BINARY_WORKLOAD_BWD(__op$) \ MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_BINARY_WORKLOAD_BWD, __op$) @@ -307,6 +310,7 @@ IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::degrees_grad); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::radians); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::radians_grad); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::nt); // NOLINT() +IMPLEMENT_UNARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::np_logical_not); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::nt); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::clip); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::clip); // NOLINT() @@ -378,6 +382,12 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ne); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ne); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::eq); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::eq); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::np_equal); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::np_not_equal); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::np_greater); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::np_greater_equal); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::np_less); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::np_less_equal); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_and); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_and); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_or); // NOLINT() diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc index 2416c128eddd..d50f9684c22e 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc @@ -24,7 +24,7 @@ */ #if MXNET_USE_MKLDNN == 100 -#include "../../nn/mkldnn/mkldnn_flatten-inl.h" +#include "../../nn/mkldnn/mkldnn_ops-inl.h" #include "../quantization_utils.h" namespace mxnet { @@ -42,7 +42,7 @@ static void MKLDNNQuantizedFlattenForward(const nnvm::NodeAttrs& attrs, const Op const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - MKLDNNFlattenForward(attrs, ctx, inputs[0], req[0], outputs[0]); + MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]); outputs[1].data().dptr()[0] = inputs[1].data().dptr()[0]; outputs[2].data().dptr()[0] = inputs[2].data().dptr()[0]; } diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index 29c476d06d1c..6a612e6f1cd5 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -186,12 +186,12 @@ inline int BinaryBroadcastShapeCompact(const mxnet::TShape& lshape, const mxnet: } namespace mxnet_op { -template +template struct binary_broadcast_kernel { /*! \brief Map function for binary_broadcast_kernel */ MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, const Shape &lstride, const Shape &rstride, - const Shape &oshape, DType *lhs, DType *rhs, + const Shape &oshape, IType *lhs, IType *rhs, DType *out) { Shape coord = unravel(base, oshape); auto lidx = static_cast(dot(coord, lstride)); @@ -209,7 +209,7 @@ struct binary_broadcast_kernel { /*! \brief Map function for binary_broadcast_kernel */ MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, const Shape &lstride, const Shape &rstride, - const Shape &oshape, DType lhs, DType *rhs, + const Shape &oshape, IType lhs, IType *rhs, DType *out) { Shape coord = unravel(base, oshape); auto lidx = static_cast(dot(coord, lstride)); @@ -306,7 +306,7 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, mshadow::Shape oshape = new_oshape.get(); mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: + mxnet_op::Kernel, xpu>:: template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, inputs[0].dptr(), inputs[1].dptr(), outputs[0].dptr()); }); @@ -315,6 +315,36 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, } } +template +void BinaryBroadcastComputeLogic(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (outputs[0].shape_.Size() == 0U) return; + mxnet::TShape new_lshape, new_rshape, new_oshape; + int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, + &new_lshape, &new_rshape, &new_oshape); + if (!ndim) { + ElemwiseBinaryOp::ComputeLogic(attrs, ctx, inputs, req, outputs); + } else { + if (req[0] != kNullOp) { + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + inputs[0].dptr(), inputs[1].dptr(), + outputs[0].dptr()); + }); + }); + } + } +} + template void BinaryBroadcastCsrDnsCsrImpl(const OpContext& ctx, const NDArray& csr, @@ -413,11 +443,11 @@ void BinaryBroadcastCsrDnsDnsImpl(const OpContext& ctx, Shape lstride = calc_stride(new_csrshape.get()); Shape rstride = calc_stride(new_dnsshape.get()); if (reverse && std::is_same::value) { - Kernel, xpu>:: + Kernel, xpu>:: template LaunchEx(s, new_oshape.Size(), req, lstride, rstride, oshape, DType(0), dns_data.dptr(), out_data.dptr()); } else { - Kernel, xpu>:: + Kernel, xpu>:: template LaunchEx(s, new_oshape.Size(), req, lstride, rstride, oshape, DType(0), dns_data.dptr(), out_data.dptr()); } diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h index 9c1d8b17fdea..6f444aed21fe 100644 --- a/src/operator/tensor/elemwise_binary_op.h +++ b/src/operator/tensor/elemwise_binary_op.h @@ -497,6 +497,32 @@ class ElemwiseBinaryOp : public OpBase { } } + template + static void ComputeLogic(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + if (req[0] != kNullOp) { + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { + const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) + + DataType::kLanes - 1) / DataType::kLanes; + if (size != 0) { + Kernel, xpu>::Launch(s, size, + outputs[0].dptr(), + inputs[0].dptr(), + inputs[1].dptr()); + } + }); + }); + } + } + template static void ComputeWithHalf2(const nnvm::NodeAttrs &attrs, const OpContext &ctx, diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h index c78841641214..02b005eed995 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op.h +++ b/src/operator/tensor/elemwise_binary_scalar_op.h @@ -244,6 +244,26 @@ class BinaryScalarOp : public UnaryOp { }); } + template + static void ComputeLogic(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + DCHECK_EQ(inputs.size(), 1); + DCHECK_EQ(outputs.size(), 1); + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + const double alpha = nnvm::get(attrs.parsed); + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + mxnet_op::Kernel, xpu>::Launch( + s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr(), DType(alpha)); + }); + }); + } + template static void ComputeEx(const nnvm::NodeAttrs &attrs, const OpContext &ctx, diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index 22e7652a4019..b7625fccf258 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -251,6 +251,23 @@ class UnaryOp : public OpBase { }); } + template + static void ComputeLogic(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + if (inputs[0].Size() != 0) { + mxnet_op::Kernel, xpu>::Launch( + s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr()); + } + }); + }); + } + template static void ComputeEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -408,7 +425,7 @@ struct CastParam : public dmlc::Parameter { int dtype; DMLC_DECLARE_PARAMETER(CastParam) { DMLC_DECLARE_FIELD(dtype) - MXNET_ADD_ALL_TYPES + MXNET_ADD_ALL_TYPES_WITH_BOOL .describe("Output data type."); } }; @@ -432,7 +449,7 @@ void CastCompute(const nnvm::NodeAttrs& attrs, using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DstDType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DstDType, { Tensor out = outputs[0].FlatTo1D(s); MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, SrcDType, { Tensor data = inputs[0].FlatTo1D(s); diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc index 117cfa96518a..7c35c44305a8 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cc +++ b/src/operator/tensor/elemwise_unary_op_basic.cc @@ -650,6 +650,7 @@ Example:: DMLC_REGISTER_PARAMETER(CastParam); NNVM_REGISTER_OP(Cast) .add_alias("cast") +.add_alias("_npi_cast") .add_alias("_npx_cast") .describe(R"code(Casts all elements of the input to a new type. diff --git a/src/operator/tensor/histogram.cc b/src/operator/tensor/histogram.cc index 754475bff9ad..b7896e9e0016 100644 --- a/src/operator/tensor/histogram.cc +++ b/src/operator/tensor/histogram.cc @@ -123,6 +123,7 @@ void HistogramForwardImpl(const OpContext& ctx, DMLC_REGISTER_PARAMETER(HistogramParam); NNVM_REGISTER_OP(_histogram) +.add_alias("_npi_histogram") .describe(R"code(This operators implements the histogram function. Example:: diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 237e6d57ae10..2c1bc2b2098f 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -43,6 +43,11 @@ #include #endif +#ifdef __CUDACC__ +#include "./pseudo2DTranspose_op-inl.cuh" +#endif + + namespace mxnet { namespace op { @@ -301,7 +306,6 @@ MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index } } - template void TransposeImpl(RunContext ctx, const TBlob& src, @@ -313,6 +317,17 @@ void TransposeImpl(RunContext ctx, // zero-size tensor, no need to compute if (src.shape_.Size() == 0U) return; Stream *s = ctx.get_stream(); +#ifdef __CUDACC__ + // This transpose can be used only if there exist n and m such that: + // params = (0, ..., n-1, n+m, ..., params.size, n, ..., n+m-1) + // Example: (0, 2, 3, 1) or (0, 3, 1, 2), but not (0, 2, 1, 3). + if (isPseudo2DTranspose(axes)) { + MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, { + transpose_pseudo2D(ret, src, axes, s); + }); + return; + } +#endif MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, { switch (axes.ndim()) { case 0: { diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index 6bf1ec0c5d5c..bd683c90aede 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -116,13 +116,9 @@ static void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 1U); // If inputs are supposed to be in MKLDNN format and - // MKLDNNsupport the data type or the shape. Then convert + // MKLDNN support the data type or the shape. Then convert // it to the output format and shape - if (SupportMKLDNNReshape(inputs[0], outputs[0])) { - MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]); - return; - } - FallBackCompute(UnaryOp::IdentityCompute, attrs, ctx, inputs, req, outputs); + MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]); } inline static bool ReshapeStorageType(const nnvm::NodeAttrs& attrs, @@ -140,66 +136,42 @@ inline static bool ReshapeStorageType(const nnvm::NodeAttrs& attrs, NNVM_REGISTER_OP(Reshape) .add_alias("reshape") .describe(R"code(Reshapes the input array. - .. note:: ``Reshape`` is deprecated, use ``reshape`` - Given an array and a shape, this function returns a copy of the array in the new shape. The shape is a tuple of integers such as (2,3,4). The size of the new shape should be same as the size of the input array. - Example:: - reshape([1,2,3,4], shape=(2,2)) = [[1,2], [3,4]] - Some dimensions of the shape can take special values from the set {0, -1, -2, -3, -4}. The significance of each is explained below: - - ``0`` copy this dimension from the input to the output shape. - Example:: - - input shape = (2,3,4), shape = (4,0,2), output shape = (4,3,2) - input shape = (2,3,4), shape = (2,0,0), output shape = (2,3,4) - - ``-1`` infers the dimension of the output shape by using the remainder of the input dimensions keeping the size of the new array same as that of the input array. At most one dimension of shape can be -1. - Example:: - - input shape = (2,3,4), shape = (6,1,-1), output shape = (6,1,4) - input shape = (2,3,4), shape = (3,-1,8), output shape = (3,1,8) - input shape = (2,3,4), shape=(-1,), output shape = (24,) - - ``-2`` copy all/remainder of the input dimensions to the output shape. - Example:: - - input shape = (2,3,4), shape = (-2,), output shape = (2,3,4) - input shape = (2,3,4), shape = (2,-2), output shape = (2,3,4) - input shape = (2,3,4), shape = (-2,1,1), output shape = (2,3,4,1,1) - - ``-3`` use the product of two consecutive dimensions of the input shape as the output dimension. - Example:: - - input shape = (2,3,4), shape = (-3,4), output shape = (6,4) - input shape = (2,3,4,5), shape = (-3,-3), output shape = (6,20) - input shape = (2,3,4), shape = (0,-3), output shape = (2,12) - input shape = (2,3,4), shape = (-3,-2), output shape = (6,4) - - ``-4`` split one dimension of the input into two dimensions passed subsequent to -4 in shape (can contain -1). - Example:: - - input shape = (2,3,4), shape = (-4,1,2,-2), output shape =(1,2,3,4) - input shape = (2,3,4), shape = (2,-4,-1,3,-2), output shape = (2,1,3,4) - If the argument `reverse` is set to 1, then the special values are inferred from right to left. - Example:: - - without reverse=1, for input shape = (10,5,4), shape = (-1,0), output shape would be (40,5) - with reverse=1, output shape will be (50,4). - )code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) @@ -227,6 +199,7 @@ If the argument `reverse` is set to 1, then the special values are inferred from .add_argument("data", "NDArray-or-Symbol", "Input data to reshape.") .add_arguments(ReshapeParam::__FIELDS__()); +#if MXNET_USE_MKLDNN == 100 static void FlattenEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -234,22 +207,12 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 1U); -#if MXNET_USE_MKLDNN == 100 - auto data_ndim = inputs[0].shape().ndim(); - if (data_ndim <= 4 && inputs[0].dtype() == mshadow::kFloat32) { - MKLDNNFlattenForward(attrs, ctx, inputs[0], req[0], outputs[0]); - return; - } else { - // This happens if inputs are supposed to be in MKLDNN format - // but MKLDNN doesn't support the data type or the shape. We're - // forced to convert it to the default format. - FallBackCompute(UnaryOp::IdentityCompute, attrs, ctx, inputs, req, outputs); - return; - } -#endif + // If inputs are supposed to be in MKLDNN format and + // MKLDNN support the data type or the shape. Then convert + // it to the output format and shape + MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]); } -#if MXNET_USE_MKLDNN == 100 static inline bool FlattenStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, @@ -266,17 +229,12 @@ NNVM_REGISTER_OP(Flatten) .add_alias("flatten") .add_alias("_npx_batch_flatten") .describe(R"code(Flattens the input array into a 2-D array by collapsing the higher dimensions. - .. note:: `Flatten` is deprecated. Use `flatten` instead. - For an input array with shape ``(d1, d2, ..., dk)``, `flatten` operation reshapes the input array into an output array of shape ``(d1, d2*...*dk)``. - Note that the behavior of this function is different from numpy.ndarray.flatten, which behaves similar to mxnet.ndarray.reshape((-1,)). - Example:: - x = [[ [1,2,3], [4,5,6], @@ -286,23 +244,19 @@ Example:: [4,5,6], [7,8,9] ]], - flatten(x) = [[ 1., 2., 3., 4., 5., 6., 7., 8., 9.], [ 1., 2., 3., 4., 5., 6., 7., 8., 9.]] - )code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) .set_attr("FInferShape", FlattenShape) .set_attr("FInferType", ElemwiseType<1, 1>) -#if MXNET_USE_MKLDNN == 100 -.set_attr("FInferStorageType", FlattenStorageType) -#endif .set_attr("FGradient", ElemwiseGradUseNone{ "_backward_copy" }) .set_attr("FCompute", UnaryOp::IdentityCompute) -.set_attr("FComputeEx", FlattenEx) #if MXNET_USE_MKLDNN == 100 .set_attr("TIsMKLDNN", true) +.set_attr("FComputeEx", FlattenEx) +.set_attr("FInferStorageType", FlattenStorageType) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) @@ -351,30 +305,21 @@ inline static bool TransposeStorageType(const nnvm::NodeAttrs& attrs, NNVM_REGISTER_OP(transpose) .describe(R"code(Permutes the dimensions of an array. - Examples:: - x = [[ 1, 2], [ 3, 4]] - transpose(x) = [[ 1., 3.], [ 2., 4.]] - x = [[[ 1., 2.], [ 3., 4.]], - [[ 5., 6.], [ 7., 8.]]] - transpose(x) = [[[ 1., 5.], [ 3., 7.]], - [[ 2., 6.], [ 4., 8.]]] - transpose(x, axes=(1,0,2)) = [[[ 1., 2.], [ 5., 6.]], - [[ 3., 4.], [ 7., 8.]]] )code" ADD_FILELINE) @@ -420,12 +365,10 @@ static void ExpandDimEx(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 1U); - auto data_ndim = inputs[0].shape().ndim(); - if (data_ndim <= 3 && inputs[0].dtype() == mshadow::kFloat32) { - MKLDNNExpandDimsForward(attrs, ctx, inputs[0], req[0], outputs[0]); - return; - } - FallBackCompute(UnaryOp::IdentityCompute, attrs, ctx, inputs, req, outputs); + // If inputs are supposed to be in MKLDNN format and + // MKLDNN support the data type or the shape. Then convert + // it to the output format and shape + MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]); } inline static bool ExpandDimStorageType(const nnvm::NodeAttrs& attrs, @@ -442,19 +385,14 @@ inline static bool ExpandDimStorageType(const nnvm::NodeAttrs& attrs, NNVM_REGISTER_OP(expand_dims) .add_alias("_npi_expand_dims") .describe(R"code(Inserts a new axis of size 1 into the array shape - For example, given ``x`` with shape ``(2,3,4)``, then ``expand_dims(x, axis=1)`` will return a new array with shape ``(2,1,3,4)``. - )code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) .set_attr_parser(ParamParser) .set_attr("FInferShape", ExpandDimShape) .set_attr("FInferType", ElemwiseType<1, 1>) -#if MXNET_USE_MKLDNN == 100 -.set_attr("FInferStorageType", ExpandDimStorageType) -#endif .set_attr("FInplaceOption", [](const NodeAttrs& attrs){ return std::vector >{{0, 0}}; @@ -466,8 +404,9 @@ will return a new array with shape ``(2,1,3,4)``. .set_attr("FGradient", ElemwiseGradUseNone{"_backward_reshape"}) .set_attr("FCompute", UnaryOp::IdentityCompute) #if MXNET_USE_MKLDNN == 100 -.set_attr("FComputeEx", ExpandDimEx) .set_attr("TIsMKLDNN", true) +.set_attr("FComputeEx", ExpandDimEx) +.set_attr("FInferStorageType", ExpandDimStorageType) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) @@ -503,44 +442,33 @@ NNVM_REGISTER_OP(slice) MXNET_ADD_SPARSE_OP_ALIAS(slice) .add_alias("crop") .describe(R"code(Slices a region of the array. - .. note:: ``crop`` is deprecated. Use ``slice`` instead. - This function returns a sliced array between the indices given by `begin` and `end` with the corresponding `step`. - For an input array of ``shape=(d_0, d_1, ..., d_n-1)``, slice operation with ``begin=(b_0, b_1...b_m-1)``, ``end=(e_0, e_1, ..., e_m-1)``, and ``step=(s_0, s_1, ..., s_m-1)``, where m <= n, results in an array with the shape ``(|e_0-b_0|/|s_0|, ..., |e_m-1-b_m-1|/|s_m-1|, d_m, ..., d_n-1)``. - The resulting array's *k*-th dimension contains elements from the *k*-th dimension of the input array starting from index ``b_k`` (inclusive) with step ``s_k`` until reaching ``e_k`` (exclusive). - If the *k*-th elements are `None` in the sequence of `begin`, `end`, and `step`, the following rule will be used to set default values. If `s_k` is `None`, set `s_k=1`. If `s_k > 0`, set `b_k=0`, `e_k=d_k`; else, set `b_k=d_k-1`, `e_k=-1`. - The storage type of ``slice`` output depends on storage types of inputs - - slice(csr) = csr - otherwise, ``slice`` generates output with default storage - .. note:: When input data storage type is csr, it only supports step=(), or step=(None,), or step=(1,) to generate a csr output. For other step parameter values, it falls back to slicing a dense tensor. - Example:: - x = [[ 1., 2., 3., 4.], [ 5., 6., 7., 8.], [ 9., 10., 11., 12.]] - slice(x, begin=(0,1), end=(2,4)) = [[ 2., 3., 4.], [ 6., 7., 8.]] slice(x, begin=(None, 0), end=(None, 3), step=(-1, 2)) = [[9., 11.], @@ -620,23 +548,17 @@ NNVM_REGISTER_OP(_slice_assign_scalar) NNVM_REGISTER_OP(slice_axis) .describe(R"code(Slices along a given axis. - Returns an array slice along a given `axis` starting from the `begin` index to the `end` index. - Examples:: - x = [[ 1., 2., 3., 4.], [ 5., 6., 7., 8.], [ 9., 10., 11., 12.]] - slice_axis(x, axis=0, begin=1, end=3) = [[ 5., 6., 7., 8.], [ 9., 10., 11., 12.]] - slice_axis(x, axis=1, begin=0, end=2) = [[ 1., 2.], [ 5., 6.], [ 9., 10.]] - slice_axis(x, axis=1, begin=-3, end=-1) = [[ 2., 3.], [ 6., 7.], [ 10., 11.]] @@ -660,46 +582,31 @@ NNVM_REGISTER_OP(_backward_slice_axis) NNVM_REGISTER_OP(slice_like) .describe(R"code(Slices a region of the array like the shape of another array. - This function is similar to ``slice``, however, the `begin` are always `0`s and `end` of specific axes are inferred from the second input `shape_like`. - Given the second `shape_like` input of ``shape=(d_0, d_1, ..., d_n-1)``, a ``slice_like`` operator with default empty `axes`, it performs the following operation: - `` out = slice(input, begin=(0, 0, ..., 0), end=(d_0, d_1, ..., d_n-1))``. - When `axes` is not empty, it is used to speficy which axes are being sliced. - Given a 4-d input data, ``slice_like`` operator with ``axes=(0, 2, -1)`` will perform the following operation: - `` out = slice(input, begin=(0, 0, 0, 0), end=(d_0, None, d_2, d_3))``. - Note that it is allowed to have first and second input with different dimensions, however, you have to make sure the `axes` are specified and not exceeding the dimension limits. - For example, given `input_1` with ``shape=(2,3,4,5)`` and `input_2` with ``shape=(1,2,3)``, it is not allowed to use: - `` out = slice_like(a, b)`` because ndim of `input_1` is 4, and ndim of `input_2` is 3. - The following is allowed in this situation: - `` out = slice_like(a, b, axes=(0, 2))`` - Example:: - x = [[ 1., 2., 3., 4.], [ 5., 6., 7., 8.], [ 9., 10., 11., 12.]] - y = [[ 0., 0., 0.], [ 0., 0., 0.]] - slice_like(x, y) = [[ 1., 2., 3.] [ 5., 6., 7.]] slice_like(x, y, axes=(0, 1)) = [[ 1., 2., 3.] @@ -745,23 +652,15 @@ NNVM_REGISTER_OP(clip) MXNET_ADD_SPARSE_OP_ALIAS(clip) .add_alias("_npi_clip") .describe(R"code(Clips (limits) the values in an array. - Given an interval, values outside the interval are clipped to the interval edges. Clipping ``x`` between `a_min` and `a_max` would be:: - .. math:: - clip(x, a_min, a_max) = \max(\min(x, a_max), a_min)) - Example:: - x = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - clip(x,1,8) = [ 1., 1., 2., 3., 4., 5., 6., 7., 8., 8.] - The storage type of ``clip`` output depends on storage types of inputs and the a_min, a_max \ parameter values: - - clip(default) = default - clip(row_sparse, a_min <= 0, a_max >= 0) = row_sparse - clip(csr, a_min <= 0, a_max >= 0) = csr @@ -769,7 +668,6 @@ parameter values: - clip(row_sparse, a_min > 0, a_max > 0) = default - clip(csr, a_min < 0, a_max < 0) = csr - clip(csr, a_min > 0, a_max > 0) = csr - )code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) @@ -823,28 +721,20 @@ NNVM_REGISTER_OP(_backward_clip) NNVM_REGISTER_OP(repeat) .add_alias("_np_repeat") .describe(R"code(Repeats elements of an array. - By default, ``repeat`` flattens the input array into 1-D and then repeats the elements:: - x = [[ 1, 2], [ 3, 4]] - repeat(x, repeats=2) = [ 1., 1., 2., 2., 3., 3., 4., 4.] - The parameter ``axis`` specifies the axis along which to perform repeat:: - repeat(x, repeats=2, axis=1) = [[ 1., 1., 2., 2.], [ 3., 3., 4., 4.]] - repeat(x, repeats=2, axis=0) = [[ 1., 2.], [ 1., 2.], [ 3., 4.], [ 3., 4.]] - repeat(x, repeats=2, axis=-1) = [[ 1., 1., 2., 2.], [ 3., 3., 4., 4.]] - )code" ADD_FILELINE) .set_num_outputs(1) .set_num_inputs(1) @@ -874,35 +764,25 @@ NNVM_REGISTER_OP(_backward_repeat) NNVM_REGISTER_OP(tile) .add_alias("_npi_tile") .describe(R"code(Repeats the whole array multiple times. - If ``reps`` has length *d*, and input array has dimension of *n*. There are three cases: - - **n=d**. Repeat *i*-th dimension of the input by ``reps[i]`` times:: - x = [[1, 2], [3, 4]] - tile(x, reps=(2,3)) = [[ 1., 2., 1., 2., 1., 2.], [ 3., 4., 3., 4., 3., 4.], [ 1., 2., 1., 2., 1., 2.], [ 3., 4., 3., 4., 3., 4.]] - - **n>d**. ``reps`` is promoted to length *n* by pre-pending 1's to it. Thus for an input shape ``(2,3)``, ``repos=(2,)`` is treated as ``(1,2)``:: - - tile(x, reps=(2,)) = [[ 1., 2., 1., 2.], [ 3., 4., 3., 4.]] - - **n("FInferType", ElemwiseType<1, 1>) .set_attr("FCompute", UnaryOp::IdentityCompute) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_squeeze"}) -.add_argument("data", "NDArray-or-Symbol[]", "data to squeeze") +.add_argument("data", "NDArray-or-Symbol", "data to squeeze") .add_arguments(SqueezeParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_squeeze) @@ -1071,22 +939,17 @@ NNVM_REGISTER_OP(depth_to_space) .describe(R"code(Rearranges(permutes) data from depth into blocks of spatial data. Similar to ONNX DepthToSpace operator: https://github.com/onnx/onnx/blob/master/docs/Operators.md#DepthToSpace. -The output is a new tensor where the values from depth dimension are moved in spatial blocks +The output is a new tensor where the values from depth dimension are moved in spatial blocks to height and width dimension. The reverse of this operation is ``space_to_depth``. - .. math:: - \begin{gather*} x \prime = reshape(x, [N, block\_size, block\_size, C / (block\_size ^ 2), H * block\_size, W * block\_size]) \\ x \prime \prime = transpose(x \prime, [0, 3, 4, 1, 5, 2]) \\ y = reshape(x \prime \prime, [N, C / (block\_size ^ 2), H * block\_size, W * block\_size]) \end{gather*} - -where :math:`x` is an input tensor with default layout as :math:`[N, C, H, W]`: [batch, channels, height, width] +where :math:`x` is an input tensor with default layout as :math:`[N, C, H, W]`: [batch, channels, height, width] and :math:`y` is the output tensor of layout :math:`[N, C / (block\_size ^ 2), H * block\_size, W * block\_size]` - Example:: - x = [[[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], @@ -1095,7 +958,6 @@ Example:: [15, 16, 17]], [[18, 19, 20], [21, 22, 23]]]] - depth_to_space(x, 2) = [[[[0, 6, 1, 7, 2, 8], [12, 18, 13, 19, 14, 20], [3, 9, 4, 10, 5, 11], @@ -1122,30 +984,22 @@ Example:: NNVM_REGISTER_OP(space_to_depth) .describe(R"code(Rearranges(permutes) blocks of spatial data into depth. Similar to ONNX SpaceToDepth operator: -https://github.com/onnx/onnx/blob/master/docs/Operators.md#SpaceToDepth - -The output is a new tensor where the values from height and width dimension are +https://github.com/onnx/onnx/blob/master/docs/Operators.md#SpaceToDepth +The output is a new tensor where the values from height and width dimension are moved to the depth dimension. The reverse of this operation is ``depth_to_space``. - .. math:: - \begin{gather*} x \prime = reshape(x, [N, C, H / block\_size, block\_size, W / block\_size, block\_size]) \\ x \prime \prime = transpose(x \prime, [0, 3, 5, 1, 2, 4]) \\ y = reshape(x \prime \prime, [N, C * (block\_size ^ 2), H / block\_size, W / block\_size]) \end{gather*} - -where :math:`x` is an input tensor with default layout as :math:`[N, C, H, W]`: [batch, channels, height, width] +where :math:`x` is an input tensor with default layout as :math:`[N, C, H, W]`: [batch, channels, height, width] and :math:`y` is the output tensor of layout :math:`[N, C * (block\_size ^ 2), H / block\_size, W / block\_size]` - Example:: - x = [[[[0, 6, 1, 7, 2, 8], [12, 18, 13, 19, 14, 20], [3, 9, 4, 10, 5, 11], [15, 21, 16, 22, 17, 23]]]] - - space_to_depth(x, 2) = [[[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], @@ -1176,9 +1030,7 @@ Example:: NNVM_REGISTER_OP(_split_v2) .add_alias("_npi_split") .describe(R"code(Splits an array along a particular axis into multiple sub-arrays. - Example:: - x = [[[ 1.] [ 2.]] [[ 3.] @@ -1186,61 +1038,44 @@ Example:: [[ 5.] [ 6.]]] x.shape = (3, 2, 1) - y = split_v2(x, axis=1, indices_or_sections=2) // a list of 2 arrays with shape (3, 1, 1) y = [[[ 1.]] [[ 3.]] [[ 5.]]] - [[[ 2.]] [[ 4.]] [[ 6.]]] - y[0].shape = (3, 1, 1) - z = split_v2(x, axis=0, indices_or_sections=3) // a list of 3 arrays with shape (1, 2, 1) z = [[[ 1.] [ 2.]]] - [[[ 3.] [ 4.]]] - [[[ 5.] [ 6.]]] - z[0].shape = (1, 2, 1) - w = split_v2(x, axis=0, indices_or_sections=(1,)) // a list of 2 arrays with shape [(1, 2, 1), (2, 2, 1)] w = [[[ 1.] [ 2.]]] - [[[3.] [4.]] - [[5.] [6.]]] - w[0].shape = (1, 2, 1) w[1].shape = (2, 2, 1) - `squeeze_axis=True` removes the axis with length 1 from the shapes of the output arrays. **Note** that setting `squeeze_axis` to ``1`` removes axis with length 1 only along the `axis` which it is split. Also `squeeze_axis` can be set to true only if ``input.shape[axis] == indices_or_sections``. - Example:: - z = split_v2(x, axis=0, indices_or_sections=3, squeeze_axis=1) // a list of 3 arrays with shape (2, 1) z = [[ 1.] [ 2.]] - [[ 3.] [ 4.]] - [[ 5.] [ 6.]] z[0].shape = (2, 1) - )code" ADD_FILELINE) .set_attr_parser(ParamParser) .set_num_inputs(1) diff --git a/src/operator/tensor/pseudo2DTranspose_op-inl.cuh b/src/operator/tensor/pseudo2DTranspose_op-inl.cuh new file mode 100644 index 000000000000..5b7cf04daef4 --- /dev/null +++ b/src/operator/tensor/pseudo2DTranspose_op-inl.cuh @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file pseudo2DTranspose_op-inl.cuh + * \brief pseudo 2D transpose + * \author Dawid Tracz + */ + +#ifndef MXNET_OPERATOR_TENSOR_PSEUDO2DTRANSPOSE_OP_INL_CUH_ +#define MXNET_OPERATOR_TENSOR_PSEUDO2DTRANSPOSE_OP_INL_CUH_ + +#include +#include +#include +#include +#include +#include "../../common/cuda_utils.h" + + +namespace mxnet { +namespace op { +namespace cuda { + + +template +__global__ void transpose_pseudo2D(DType* out, DType* inp, + const index_t m, const index_t n, + const index_t nIterY, const index_t nIterZ) { + const index_t TSR = sizeof(CType)/sizeof(DType); // TypeSizeRatio + const index_t chunked_n = n/TSR; + const index_t chunked_m = m/TSR; + + union transp_t { + CType valChunk; + DType values[TSR]; + }; + + __shared__ DType d_shm[1024*TSR*TSR]; + CType* c_shm = reinterpret_cast(d_shm); + + CType* cInp = reinterpret_cast(inp); + CType* cOut = reinterpret_cast(out); + + for (index_t iterZ = 0; iterZ < nIterZ; iterZ++) { + const index_t blockIdx_z = gridDim.z*iterZ + blockIdx.z; + for (index_t iterY = 0; iterY < nIterY; iterY++) { + const index_t blockIdx_y = gridDim.y*iterY + blockIdx.y; + + index_t offset = blockIdx_z*m*chunked_n + + blockIdx_y*blockDim.y*TSR*chunked_n + + (index_t)blockIdx.x*blockDim.x; + + if ((blockIdx.x*blockDim.x + threadIdx.x)*TSR < n + && (blockIdx_y*blockDim.y + threadIdx.y)*TSR < m) { + // read from global memory to shared + #pragma unroll + for (index_t i = 0; i < TSR; i++) { + index_t shmIdx = (TSR*threadIdx.y + i)*blockDim.x + threadIdx.x; + c_shm[shmIdx] = cInp[offset + (TSR*threadIdx.y + i)*chunked_n + threadIdx.x]; + } + __syncthreads(); + + // read from shared to registers + transp_t tmp[TSR]; + #pragma unroll + for (index_t i = 0; i < TSR; i++) { + #pragma unroll + for (int j = 0; j < TSR; j++) { + index_t shmIdx = (TSR*threadIdx.y + j)*blockDim.x*TSR + TSR*threadIdx.x + i; + tmp[i].values[j] = d_shm[shmIdx]; + } + } + __syncthreads(); + + // write back to global output + offset = blockIdx_z*m*chunked_n + blockIdx.x*blockDim.x*TSR*chunked_m + blockIdx_y*blockDim.y; + #pragma unroll + for (index_t i = 0; i < TSR; i++) { + cOut[offset + (TSR*threadIdx.x + i)*chunked_m + threadIdx.y] = tmp[i].valChunk; + } + } + } + } +} + +} // namespace cuda + + +/*! + * \brief Calls proper version of kernel `transpose_pseudo2D` + * basing on chosen type sizes. + * \param dTypeSize Size of data type. + * \param cTypeSize Size of type that should be use to copy. + * \param grid Grid dimensions for the kernel. + * \param block Block dimensions for the kernel. + * \param stream Strem to run kernel. + * \param out Pointer to output memory. + * \param inp Pointer to input memory. + * \param m First of tensor dimensions. + * \param n Second of tensor dimensions. + */ +inline void call_transpose_pseudo2D(index_t dTypeSize, index_t cTypeSize, + dim3 grid, dim3 block, cudaStream_t stream, + void* out, void* inp, const index_t m, const index_t n, + const index_t nIterY, const index_t nIterZ) { + switch (dTypeSize) { + case (1): { + uint8_t* d_outPtr = reinterpret_cast(out); + uint8_t* d_inpPtr = reinterpret_cast(inp); + switch (cTypeSize) { + case (1): + cuda::transpose_pseudo2D<<>> + (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ); + break; + case (2): + cuda::transpose_pseudo2D<<>> + (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ); + break; + case (4): + cuda::transpose_pseudo2D<<>> + (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ); + break; + case (8): + // case guarded against in function getBestCopyTypeSize + LOG(FATAL) << "cuda::transpose_pseudo2D would take too much shared memory"; + default: + LOG(FATAL) << "Unsupported type combination"; + } + break; + } + case (2): { + uint16_t* d_outPtr = reinterpret_cast(out); + uint16_t* d_inpPtr = reinterpret_cast(inp); + switch (cTypeSize) { + case (2): + cuda::transpose_pseudo2D<<>> + (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ); + break; + case (4): + cuda::transpose_pseudo2D<<>> + (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ); + break; + case (8): + cuda::transpose_pseudo2D<<>> + (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ); + break; + default: + LOG(FATAL) << "Unsupported type combination"; + } + break; + } + case (4): { + uint32_t* d_outPtr = reinterpret_cast(out); + uint32_t* d_inpPtr = reinterpret_cast(inp); + switch (cTypeSize) { + case (4): + cuda::transpose_pseudo2D<<>> + (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ); + break; + case (8): + cuda::transpose_pseudo2D<<>> + (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ); + break; + default: + LOG(FATAL) << "Unsupported type combination"; + } + break; + } + case (8): { + uint64_t* d_outPtr = reinterpret_cast(out); + uint64_t* d_inpPtr = reinterpret_cast(inp); + switch (cTypeSize) { + case (8): + cuda::transpose_pseudo2D<<>> + (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ); + break; + default: + LOG(FATAL) << "Unsupported type combination"; + } + break; + } + default: + LOG(FATAL) << "Unsupported type combination"; + } + auto cuErr = cudaPeekAtLastError(); + CHECK_EQ(cuErr, cudaSuccess) << "Transpose kernel failure: " << cudaGetErrorString(cuErr) << ". " + << "block: (" << block.x << "," << block.y << "," << block.z << ")" + << " grid: (" << grid.x << "," << grid.y << "," << grid.z << ")"; +} + + +/*! + * \brief Checks if function `transpose_pseudo2D` can be used + * to perform transpose operation with given params. + * \param params Parameters (axes) of the transpose. + */ +inline bool isPseudo2DTranspose(const TShape& params) { + index_t n_swpDims = 1; + int i=0; + while (i < params.ndim() && i == params[i]) + i++; // leading dimensions + while (i+1 < params.ndim()) { + if(params[i]+1 != params[i+1]) + n_swpDims++; + i++; + } + return n_swpDims == 2; +} + + +struct pseudo2DSizes { + index_t leadDimS; + index_t M; + index_t N; +}; + +/*! + * \brief Calculates total size of last two dimension batches + * (according to description of transpose_pseudo2D function). + * \param shape Shape of tensor to transpose. + * \param params Parameters (axes) of the transpose. + */ +inline pseudo2DSizes getPackedTransposeDimensions(const TShape& shape, + const TShape& params) { + auto ndim = params.ndim(); + pseudo2DSizes sizes; + sizes.leadDimS = 1; + int i=0; + while (i < ndim && i == params[i]) { + sizes.leadDimS *= shape[i]; + i++; + } + sizes.N = shape[params[i++]]; + while (i < ndim && params[i]-1 == params[i-1]) { + sizes.N *= shape[params[i]]; + i++; + } + sizes.M = shape[params[i++]]; + while (i < ndim && params[i]-1 == params[i-1]) { + sizes.M *= shape[params[i]]; + i++; + } + CHECK_EQ(i, ndim) << "Too many dimensions to transpose"; + return sizes; +} + + +inline int32_t getBestCopyTypeSize(index_t dTypeSize, index_t sizeM, index_t sizeN) { + index_t cTypeSize = std::max((index_t)8, dTypeSize); + while (cTypeSize > dTypeSize) { + auto tsr = cTypeSize/dTypeSize; + if (sizeM % tsr != 0 || sizeN % tsr != 0) + cTypeSize /= 2; + else + break; + } + // if the cTypeSize is 8x dTypeSize then kernel would require 64kB shared memory + if(cTypeSize == 8 && dTypeSize == 1) + cTypeSize = 4; + return cTypeSize; +} + + +inline std::pair calculateKernelParams(pseudo2DSizes sizes, const index_t TSR) { + index_t nThreadsPerBlock = 32*32/4; // value chosen empirically + index_t thdsY = 1; + index_t thdsX = 1; + while(sizes.N/TSR > thdsX && thdsX < 32) { + thdsX *= 2; + } + thdsY = nThreadsPerBlock/thdsX; + thdsY = std::min(sizes.M/TSR, thdsY); + index_t blocksY = (sizes.M/TSR-1)/thdsY + 1; + index_t blocksX = (sizes.N/TSR-1)/thdsX + 1; + + dim3 grid(blocksX, blocksY, sizes.leadDimS); + dim3 block(thdsX, thdsY); + return {grid, block}; +} + + +/*! + * \brief Transpose given tensor according to params. + * Supports only transposes that satisfy: + * Exists n and m such that: + * params = (0, ..., n-1, n+m, ..., params.size, n, ..., n+m-1) + * Example: (0, 2, 3, 1) or (0, 3, 1, 2), but not (0, 2, 1, 3). + * \param outBlob Tensor blob to store result. + * \param inpBlob Tensor blob with input data. + * \param params Parameters (axes) of the transpose. + * \param s Pointer to GPU stream. + */ +template +void transpose_pseudo2D(const TBlob& outBlob, const TBlob& inpBlob, + const TShape& params, mshadow::Stream* s) { + const TShape& shape = inpBlob.shape_; + CHECK_EQ(shape.ndim(), params.ndim()); + auto ndim = params.ndim(); + + auto sizes = getPackedTransposeDimensions(shape, params); + + index_t cTypeSize = getBestCopyTypeSize(sizeof(DType), sizes.M, sizes.N); + // Type Size Ratio + const index_t TSR = cTypeSize/sizeof(DType); + CHECK_EQ(cTypeSize, sizeof(DType)*TSR); + + auto pair = calculateKernelParams(sizes, TSR); + dim3 grid = pair.first; + dim3 block = pair.second; + index_t nIterY = 1; + if (grid.y > std::numeric_limits::max()) { + nIterY = (grid.y - 1)/(std::numeric_limits::max() - 1) + 1; + grid.y = (grid.y - 1)/nIterY + 1; + } + index_t nIterZ = 1; + if (grid.z > std::numeric_limits::max()) { + nIterZ = (grid.z - 1)/(std::numeric_limits::max() - 1) + 1; + grid.z = (grid.z - 1)/nIterZ + 1; + } + + cudaStream_t stream = mshadow::Stream::GetStream(s); + call_transpose_pseudo2D(sizeof(DType), cTypeSize, grid, block, stream, + outBlob.dptr_, inpBlob.dptr_, sizes.M, sizes.N, nIterY, nIterZ); +} + +} // namespace op +} // namespace mxnet + + +#endif // MXNET_OPERATOR_TENSOR_PSEUDO2DTRANSPOSE_OP_INL_CUH_ diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 99856f770d5c..e51e220c232f 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -21,7 +21,8 @@ from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor from mxnet import gluon, nd -from tests.python.unittest.common import with_seed, teardown +from tests.python.unittest.common import with_seed, with_post_test_cleanup +from nose.tools import with_setup # dimension constants MEDIUM_X = 10000 @@ -84,20 +85,20 @@ def test_ndarray_random_randint(): @with_seed() def test_ndarray_random_exponential(): - scale_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y)) + scale_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X)) a = nd.random.exponential(scale=scale_array, shape=(SMALL_X, SMALL_Y)) assert a[-1][0][0][0] >= 0 - assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y) + assert a.shape == (MEDIUM_X, SMALL_X, SMALL_X, SMALL_Y) @with_seed() def test_ndarray_random_gamma(): - alpha_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y)) - beta_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y)) + alpha_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X)) + beta_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X)) a = nd.random.gamma(alpha=alpha_array, beta=beta_array, shape=(SMALL_X, SMALL_Y)) assert a[-1][0][0][0] >= 0 - assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y) + assert a.shape == (MEDIUM_X, SMALL_X, SMALL_X, SMALL_Y) @with_seed() @@ -108,50 +109,50 @@ def test_ndarray_random_multinomial(): assert a[-1] >= 0 assert a.shape == (LARGE_X,) # test for NDArray multi-dimension shape - a = nd.random.multinomial(probs, shape=(SMALL_X, SMALL_Y)) + a = nd.random.multinomial(probs, shape=(2, SMALL_Y)) assert a[-1][0][0] >= 0 - assert a.shape == (LARGE_X, SMALL_X, SMALL_Y) + assert a.shape == (LARGE_X, 2, SMALL_Y) # test log_likelihood output shape - a = nd.random.multinomial(probs, shape=(SMALL_X, SMALL_Y), get_prob=True) - assert a[-1][0][0] >= 0 - assert a[0].shape == (LARGE_X, SMALL_X, SMALL_Y) and a[0].shape == a[1].shape + a = nd.random.multinomial(probs, shape=(2, SMALL_Y), get_prob=True) + assert a[0][0][0][0] >= 0 + assert a[0].shape == (LARGE_X, 2, SMALL_Y) and a[0].shape == a[1].shape @with_seed() def test_ndarray_random_generalized_negative_binomial(): - alpha_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y)) - mu_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y)) + alpha_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X)) + mu_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X)) a = nd.random.generalized_negative_binomial(mu=mu_array, alpha=alpha_array, shape=(SMALL_X, SMALL_Y)) assert a[-1][0][0][0] >= 0 - assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y) + assert a.shape == (MEDIUM_X, SMALL_X, SMALL_X, SMALL_Y) @with_seed() def test_ndarray_random_negative_binomial(): - k_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y)) - p_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y)) + k_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X)) + p_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X)) a = nd.random.negative_binomial(k=k_array, p=p_array, shape=(SMALL_X, SMALL_Y)) assert a[-1][0][0][0] >= 0 - assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y) + assert a.shape == (MEDIUM_X, SMALL_X, SMALL_X, SMALL_Y) @with_seed() def test_ndarray_random_normal(): - scale_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y)) - loc_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y)) + scale_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X)) + loc_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X)) a = nd.random.normal(loc=loc_array, scale=scale_array, shape=(SMALL_X, SMALL_Y)) - assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y) + assert a.shape == (MEDIUM_X, SMALL_X, SMALL_X, SMALL_Y) @with_seed() def test_ndarray_random_poisson(): - lambda_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y)) + lambda_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X)) a = nd.random.poisson(lam=lambda_array, shape=(SMALL_X, SMALL_Y)) assert a[-1][0][0][0] >= 0 - assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y) + assert a.shape == (MEDIUM_X, SMALL_X, SMALL_X, SMALL_Y) @with_seed() @@ -165,7 +166,7 @@ def test_ndarray_random_randn(): @with_seed() def test_ndarray_random_shuffle(): a = nd.ones(shape=(LARGE_X, SMALL_Y)) - a[-1] == 3 # assign 3 to entire last row + a[-1] = 3 # assign 3 to entire last row a = nd.random.shuffle(a) # slice first column from shuffled array # pass LARGE_X values to numpy instead of LARGE_X*SMALL_Y @@ -175,7 +176,7 @@ def test_ndarray_random_shuffle(): assert len(unique_a) == 2 # only 2 unique values assert unique_a[0] == 1 # first unique value is 1 assert unique_a[1] == 3 # second unique value is 3 - assert a.shape[0] == (LARGE_X, SMALL_Y) + assert a.shape == (LARGE_X, SMALL_Y) def test_ndarray_empty(): @@ -269,6 +270,7 @@ def test_slice_assign(): def test_expand_dims(): a = nd.ones(shape=(LARGE_X, SMALL_Y)) res = nd.expand_dims(a, axis=1) + assert a[0][0][0] == 1 assert res.shape == (a.shape[0], 1, a.shape[1]) @@ -561,7 +563,7 @@ def test_sequence_last(): # test if returns last sequence b = nd.SequenceLast(a) - assert_almost_equal(b.asnumpy(), a[-1].asnumpy()) # only checks for (2,SMALL_Y) tensor + assert_almost_equal(b.asnumpy(), a[-1].asnumpy()) # only checks for (2, SMALL_Y) tensor assert b.shape == (2, SMALL_Y) # test with sequence length @@ -600,7 +602,7 @@ def test_softmax_cross_entropy(): def test_index_copy(): x = mx.nd.zeros((LARGE_X, SMALL_Y)) t = mx.nd.arange(1, SMALL_Y + 1).reshape((1, SMALL_Y)) - index = mx.nd.array([LARGE_X - 1]) + index = mx.nd.array([LARGE_X - 1], dtype="int64") x = mx.nd.contrib.index_copy(x, index, t) assert x[-1][-1] == t[0][-1] @@ -637,23 +639,23 @@ def test_leaky_relu(): def test_leaky(): res = mx.nd.LeakyReLU(a, act_type="leaky", slope=0.3) - assert res[-1][-1].asnumpy() == 0.3*a[-1][-1].asnumpy() + assert_almost_equal(res[-1][-1].asnumpy(), 0.3*a[-1][-1].asnumpy(), atol=1e-3, rtol=1e-3) def test_elu(): res = mx.nd.LeakyReLU(a, act_type="elu", slope=0.3) - assert res[-1][-1].asnumpy() == 0.3*(np.exp(a[-1][-1].asnumpy())-1) + assert_almost_equal(res[-1][-1].asnumpy(), 0.3*(np.exp(a[-1][-1].asnumpy())-1), atol=1e-3, rtol=1e-3) def test_selu(): lam = 1.0507009873554804934193349852946 alpha = 1.6732632423543772848170429916717 res = mx.nd.LeakyReLU(a, act_type="selu") - assert res[-1][-1].asnumpy() == (lam * alpha * (np.exp(a[-1][-1].asnumpy())-1)) + assert_almost_equal(res[-1][-1].asnumpy(), (lam * alpha * (np.exp(a[-1][-1].asnumpy())-1)), atol=1e-3, rtol=1e-3) def test_rrelu(): lower = 0.125 upper = 0.333999991 res = mx.nd.LeakyReLU(a, act_type="rrelu") - assert res[-1][-1].asnumpy() == (lower + upper) / 2 * a[-1][-1].asnumpy() + assert_almost_equal(res[0][-1][-1].asnumpy(), (lower + upper) / 2 * a[-1][-1].asnumpy(), atol=1e-3, rtol=1e-3) test_leaky() test_elu() @@ -662,31 +664,31 @@ def test_rrelu(): def test_pooling(): - a = mx.nd.ones((MEDIUM_X, MEDIUM_X, SMALL_Y, SMALL_Y)) + a = mx.nd.ones((MEDIUM_X, 200, SMALL_Y, SMALL_Y)) def test_avg_pooling(): res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='avg') - assert res[-1][-1][-1][-1] == 1.0000001 - assert res.shape == SMALL_Y - 5 + 1 + assert_almost_equal(res[-1][-1][-1][-1].asnumpy(), 1.0000001, atol=1e-3, rtol=1e-3) + assert res.shape[-1] == SMALL_Y - 5 + 1 def test_max_pooling(): res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='max') - assert res[-1][-1][-1][-1] == 1. - assert res.shape == SMALL_Y - 5 + 1 + assert_almost_equal(res[-1][-1][-1][-1].asnumpy(), 1., atol=1e-3, rtol=1e-3) + assert res.shape[-1] == SMALL_Y - 5 + 1 def test_sum_pooling(): res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='sum') - assert res[-1][-1][-1][-1] == 25 - assert res.shape == SMALL_Y - 5 + 1 + assert_almost_equal(res[-1][-1][-1][-1].asnumpy(), 25, atol=1e-3, rtol=1e-3) + assert res.shape[-1] == SMALL_Y - 5 + 1 def test_lp_pooling(): res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='lp', p_value=2) - assert res[-1][-1][-1][-1] == 5. - assert res.shape == SMALL_Y - 5 + 1 + assert_almost_equal(res[-1][-1][-1][-1].asnumpy(), 5., atol=1e-3, rtol=1e-3) + assert res.shape[-1] == SMALL_Y - 5 + 1 res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='lp', p_value=1) - assert res[-1][-1][-1][-1] == 25. - assert res.shape == SMALL_Y - 5 + 1 + assert_almost_equal(res[-1][-1][-1][-1].asnumpy(), 25., atol=1e-3, rtol=1e-3) + assert res.shape[-1] == SMALL_Y - 5 + 1 test_avg_pooling() test_max_pooling() @@ -741,36 +743,37 @@ def test_dropout(): exe = y.simple_bind(ctx=default_context(), data=shape) exe.arg_arrays[0][:] = 1 out = exe.forward(is_train=True) - assert out.shape == out.shape + nd.waitall() + assert out[0].shape == shape def test_activation(): - a = mx.nd.ones((LARGE_X, SMALL_Y)) + x = mx.nd.ones((LARGE_X, SMALL_Y)) test_x = -2 - a[-1, -1] = test_x + x[-1, -1] = test_x # Hyperbolic tangent (tanh) # y = (exp(x)-exp(-x))/(exp(x)+exp(-x)) - a = mx.nd.Activation(a, act_type="tanh") - tanh_x = (np.exp(test_x)-np.exp(-test_x))/(np.exp(test_x)+np.exp(-test_x)) - assert a[-1][-1] == tanh_x + y = mx.nd.Activation(x, act_type="tanh") + tanh_x = ((np.exp(test_x)-np.exp(-test_x))/(np.exp(test_x)+np.exp(-test_x))) + assert y[-1][-1] == np.float32(tanh_x) # Recitified Linear Unit (relu) # y = max(x,0) - a = mx.nd.Activation(a, act_type="relu") - assert a[-1][-1] == 0 + y = mx.nd.Activation(x, act_type="relu") + assert y[-1][-1] == 0 # Sigmoid # y = x/(1+abs(x)) - a = mx.nd.Activation(a, act_type="sigmoid") - sigmoid_x = 1/(1+math.exp(-test_x)) - assert a[-1][-1] == sigmoid_x + y = mx.nd.Activation(x, act_type="sigmoid") + sigmoid_x = (1/(1+math.exp(-test_x))) + assert_almost_equal(y[-1][-1].asnumpy(), np.float32(sigmoid_x), atol=1e-3, rtol=1e-3) # Soft Sign # y = 1/(1+exp(-x)) - a = mx.nd.Activation(a, act_type="softsign") - softsign_x = test_x/(1+abs(test_x)) - assert a[-1][-1] == softsign_x + y = mx.nd.Activation(x, act_type="softsign") + softsign_x = (test_x/(1+abs(test_x))) + assert y[-1][-1] == np.float32(softsign_x) # TODO: correctness of batchnorm @@ -924,8 +927,7 @@ def test_copy_to(): b = nd.array(np.zeros((SMALL_Y, LARGE_X))) c = a.copyto(b) assert c is b - print(b) - assert b[0][-1] == LARGE_X-1 + assert b[-1][-1] == SMALL_Y-1 def test_zeros_like(): @@ -957,24 +959,17 @@ def test_flatten(): assert b.shape == (LARGE_X//2, SMALL_Y*2) -def test_expand_dims(): - a = nd.array(np.ones((SMALL_Y, LARGE_X))) - b = nd.expand_dims(a, axis=1) - nd.waitall() - assert b.shape == (SMALL_Y, 1, LARGE_X) - - def test_concat(): a = nd.array(np.ones((SMALL_Y, LARGE_X))) b = nd.array(np.zeros((SMALL_Y, LARGE_X))) - c = nd.concat(a,b, dim=0) + c = nd.concat(a, b, dim=0) assert c.shape == (b.shape[0]*2, LARGE_X) def test_stack(): a = nd.array(np.ones((SMALL_Y, LARGE_X))) b = nd.array(np.zeros((SMALL_Y, LARGE_X))) - c = nd.stack(a,b, axis=1) + c = nd.stack(a, b, axis=1) assert c.shape == (b.shape[0], 2, LARGE_X) @@ -1019,7 +1014,7 @@ def test_max(): def test_norm(): a = np.array(np.full((1, LARGE_X), 3)) b = np.array(np.full((1, LARGE_X), 4)) - c = nd.array(np.concatenate((a,b), axis=0)) + c = nd.array(np.concatenate((a, b), axis=0)) d = nd.norm(c, ord=2, axis=0) e = nd.norm(c, ord=1, axis=0) assert d.shape[0] == LARGE_X @@ -1031,7 +1026,7 @@ def test_norm(): def test_argmax(): a = np.ones((SMALL_Y, LARGE_X)) b = np.zeros((SMALL_Y, LARGE_X)) - c = nd.array(np.concatenate((a,b), axis=0)) + c = nd.array(np.concatenate((a, b), axis=0)) d = nd.argmax(c, axis=0) assert d.shape[0] == LARGE_X assert d[-1] == d[0] == 0 @@ -1040,12 +1035,13 @@ def test_argmax(): def test_relu(): def frelu(x): return np.maximum(x, 0.0) + def frelu_grad(x): return 1.0 * (x > 0.0) shape = (SMALL_Y, LARGE_X) x = mx.symbol.Variable("x") y = mx.sym.relu(x) - xa = np.random.uniform(low=-1.0,high=1.0,size=shape) + xa = np.random.uniform(low=-1.0, high=1.0, size=shape) eps = 1e-4 xa[abs(xa) < eps] = 1.0 ya = frelu(xa) @@ -1059,7 +1055,7 @@ def fsigmoid(a): shape = (SMALL_Y, LARGE_X) x = mx.symbol.Variable("x") y = mx.sym.sigmoid(x) - xa = np.random.uniform(low=-1.0,high=1.0,size=shape) + xa = np.random.uniform(low=-1.0, high=1.0, size=shape) ya = fsigmoid(xa) check_symbolic_forward(y, [xa], [ya]) @@ -1116,15 +1112,6 @@ def test_idiv(): assert c[0][-1] == 2 -def test_imod(): - a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3))) - b = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 2))) - c = a - c %= b - assert c.shape == a.shape - assert c[0][-1] == 1 - - def test_eq(): a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3))) b = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3))) @@ -1198,7 +1185,7 @@ def test_slice_axis(): def test_one_hot(): - #default dtype of ndarray is float32 which cannot index elements over 2^32 + # default dtype of ndarray is float32 which cannot index elements over 2^32 a = nd.array([1, (VLARGE_X - 1)], dtype=np.int64) b = nd.one_hot(a, VLARGE_X) b[0][1] == 1 diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py index 169f5244d784..aa6cb3d75b37 100644 --- a/tests/nightly/test_large_vector.py +++ b/tests/nightly/test_large_vector.py @@ -21,7 +21,8 @@ from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, create_vector from mxnet import gluon, nd -from tests.python.unittest.common import with_seed, teardown +from tests.python.unittest.common import with_seed +from nose.tools import with_setup # dimension constants LARGE_X = 4300000000 @@ -168,7 +169,7 @@ def test_topk(): val = nd.topk(a, k=1, axis=0, dtype=np.int64, ret_typ="value") assert val == (LARGE_X - 1) - + def test_mean(): a = nd.arange(-LARGE_X // 2, LARGE_X // 2 + 1, dtype=np.int64) b = nd.mean(a, axis=0) @@ -505,14 +506,14 @@ def test_rpow(): def test_shape(): b = create_vector(size=LARGE_X) - #explicit wait_to_read() + # explicit wait_to_read() assert b[0] == 0 assert b.shape[0] == LARGE_X def test_size(): b = create_vector(size=LARGE_X) - #explicit wait_to_read() + # explicit wait_to_read() assert b[0] == 0 assert b.size == LARGE_X @@ -552,7 +553,7 @@ def test_ones_like(): def test_concat(): a = nd.ones(LARGE_X) b = nd.zeros(LARGE_X) - c = nd.concat(a,b, dim=0) + c = nd.concat(a, b, dim=0) assert c[0][0] == 1 assert c[-1][-1] == 0 assert c.shape[0] == (2 * LARGE_X) @@ -635,15 +636,6 @@ def test_idiv(): assert c[-1] == 2 -def test_imod(): - a = nd.full(LARGE_X, 3) - b = nd.full(LARGE_X, 2) - c = a - c %= b - assert c.shape == a.shape - assert c[0][-1] == 1 - - def test_eq(): a = nd.full(LARGE_X, 3) b = nd.full(LARGE_X, 3) diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py index 7cd637da3d4f..06fb16288649 100644 --- a/tests/python/unittest/common.py +++ b/tests/python/unittest/common.py @@ -272,6 +272,27 @@ def teardown(): mx.nd.waitall() +def with_post_test_cleanup(): + """ + Helper function that cleans up memory by releasing it from memory pool + Required especially by large tensor tests that have memory footprints in GBs. + """ + def test_helper(orig_test): + @make_decorator(orig_test) + def test_new(*args, **kwargs): + logger = default_logger() + try: + orig_test(*args, **kwargs) + except: + logger.info(test_msg) + raise + finally: + mx.nd.waitall() + mx.cpu().empty_cache() + return test_new + return test_helper + + def run_in_spawned_process(func, env, *args): """ Helper function to run a test in its own process. diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 20b964c96a30..53a8076f1303 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -26,7 +26,7 @@ from mxnet.gluon import HybridBlock from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, retry, use_np from common import with_seed, TemporaryDirectory -from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, has_tvm_ops, assert_exception +from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, assert_exception, is_op_runnable from mxnet.ndarray.ndarray import py_slice from mxnet.base import integer_types import scipy.stats as ss @@ -264,9 +264,10 @@ def test_np_ndarray_binary_element_wise_ops(): '/': _np.divide, 'mod': _np.mod, 'pow': _np.power, + } - if has_tvm_ops(): + if is_op_runnable(): np_op_map.update({ '==': _np.equal, '!=': _np.not_equal, @@ -468,23 +469,38 @@ def test_np_grad_ndarray_type(): @with_seed() +@use_np def test_np_ndarray_astype(): mx_data = np.array([2, 3, 4, 5], dtype=_np.int32) np_data = mx_data.asnumpy() - def check_astype_equal(dtype, copy, expect_zero_copy=False): - mx_ret = mx_data.astype(dtype=dtype, copy=copy) + class TestAstype(HybridBlock): + def __init__(self, dtype, copy): + super(TestAstype, self).__init__() + self._dtype = dtype + self._copy = copy + + def hybrid_forward(self, F, x): + return x.astype(dtype=self._dtype, copy=self._copy) + + def check_astype_equal(dtype, copy, expect_zero_copy=False, hybridize=False): + test_astype = TestAstype(dtype, copy) + if hybridize: + test_astype.hybridize() + mx_ret = test_astype(mx_data) assert type(mx_ret) is np.ndarray np_ret = np_data.astype(dtype=dtype, copy=copy) assert mx_ret.dtype == np_ret.dtype assert same(mx_ret.asnumpy(), np_ret) - if expect_zero_copy: + if expect_zero_copy and not hybridize: assert id(mx_ret) == id(mx_data) assert id(np_ret) == id(np_data) - for dtype in [_np.int8, _np.uint8, _np.int32, _np.float16, _np.float32, _np.float64]: + for dtype in [np.int8, np.uint8, np.int32, np.float16, np.float32, np.float64, np.bool, np.bool_, + 'int8', 'uint8', 'int32', 'float16', 'float32', 'float64', 'bool']: for copy in [True, False]: - check_astype_equal(dtype, copy, copy is False and mx_data.dtype == dtype) + for hybridize in [True, False]: + check_astype_equal(dtype, copy, copy is False and mx_data.dtype == dtype, hybridize) @with_seed() @@ -978,7 +994,8 @@ def test_np_multinomial(): @with_seed() -@unittest.skipUnless(has_tvm_ops(), "Comparison ops are implemented using TVM") +@unittest.skipUnless(is_op_runnable(), "Comparison ops can only run on either CPU instances, or GPU instances with" + " compute capability >= 53 if MXNet is built with USE_TVM_OP=ON") @use_np def test_np_ndarray_boolean_indexing(): def test_single_bool_index(): diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 7068bf20d897..0848b0861a76 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -228,11 +228,11 @@ def __init__(self): def hybrid_forward(self, F, x1, x2): return F.np.ldexp(x1, x2) - + def _np_ldexp(x1, x2): return x1 * _np.power(2.0, x2) - def dldx(x1, x2): + def dldx(x1, x2): grad_a = _np.power(2.0, x2) grad_b = _np_ldexp(x1, x2) * _np.log(2.0) if len(x1) == 1: @@ -241,7 +241,7 @@ def dldx(x1, x2): grad_b = _np.sum(grad_b) return [grad_a, grad_b] - shapes = [ + shapes = [ ((3, 1), (3, 1)), ((3, 1, 2), (3, 1, 2)), ((1, ),(1, )), @@ -250,7 +250,7 @@ def dldx(x1, x2): ((3, 0), (3, 0)), # zero-size shape ((0, 1), (0, 1)), # zero-size shape ((2, 0, 2), (2, 0, 2)), # zero-size shape - ] + ] for hybridize in [True, False]: for shape1, shape2 in shapes: @@ -258,7 +258,7 @@ def dldx(x1, x2): test_ldexp = TestLdexp() if hybridize: test_ldexp.hybridize() - x1 = rand_ndarray(shape=shape1, dtype=dtype).as_np_ndarray() + x1 = rand_ndarray(shape=shape1, dtype=dtype).as_np_ndarray() x1.attach_grad() x2 = rand_ndarray(shape=shape2, dtype=dtype).as_np_ndarray() x2.attach_grad() @@ -997,13 +997,13 @@ def __init__(self, axis): self._axis = axis def hybrid_forward(self, F, x): - return F.np.squeeze(x, axis=self._axis) + return F.np.squeeze(x, self._axis) for shape, axis in config: data_np = _np.random.uniform(size=shape) data_mx = np.array(data_np, dtype=data_np.dtype) - ret_np = _np.squeeze(data_np, axis=axis) - ret_mx = np.squeeze(data_mx, axis=axis) + ret_np = _np.squeeze(data_np, axis) + ret_mx = np.squeeze(data_mx, axis) assert_almost_equal(ret_mx.asnumpy(), ret_np, rtol=1e-5, atol=1e-6, use_broadcast=False) net = TestSqueeze(axis) @@ -1163,6 +1163,7 @@ def hybrid_forward(self, F, a): axeses.append(tuple(axes)) random.shuffle(axes) axeses.append(tuple(axes)) + axeses.append([i - len(axes) for i in axes]) for axes in axeses: test_trans = TestTranspose(axes) if hybridize: @@ -1179,10 +1180,15 @@ def hybrid_forward(self, F, a): np_backward = np_transpose_grad(np_out.shape, dtype, axes) assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5, use_broadcast=False) - mx_out = np.transpose(x, axes) - np_out = _np.transpose(x.asnumpy(), axes) + mx_out = x.transpose(axes) + np_out = x.asnumpy().transpose(axes) assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False) + if isinstance(axes, (list, tuple)): + mx_out = x.transpose(*axes) + np_out = x.asnumpy().transpose(*axes) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False) + @with_seed() @use_np @@ -1342,6 +1348,8 @@ def hybrid_forward(self, F, a, *args, **kwargs): y = mx_func(mx_test_data) assert y.shape == np_out.shape assert_almost_equal(y.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + if np_out.dtype == np.bool_: + assert y.dtype == np.bool_ if ref_grad: y.backward() @@ -2039,6 +2047,23 @@ def hybrid_forward(self, F, a): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) +@with_seed() +@use_np +def test_np_histogram(): + shapes = [(), (3, 4), (3, 0)] + + for shape in shapes: + mx_a = np.random.uniform(0.0, 10.0, size=shape) + np_a = mx_a.asnumpy() + mx_bins = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5., 6., 7., 8., 9., 10.]) + np_bins = mx_bins.asnumpy() + for bins, _range in [(20, (0.0, 10.0)), (mx_bins, None)]: + mx_cnts, mx_bins = np.histogram(mx_a, bins=bins, range=_range) + np_cnts, np_bins = _np.histogram(np_a, bins=bins if isinstance(bins, mx.base.numeric_types) else bins.asnumpy(), range=_range) + assert_almost_equal(mx_cnts.asnumpy(), np_cnts, rtol=1e-3, atol=1e-5) + assert_almost_equal(mx_bins.asnumpy(), np_bins, rtol=1e-3, atol=1e-5) + + @with_seed() @use_np def test_np_choice(): diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 713b11ead48b..41e48242b2cb 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -2911,6 +2911,45 @@ def test_transpose(): @with_seed() +def test_pseudo2dtranspose(): + def getTwoInts(mn, mx): + n1 = np.random.randint(mn, mx) + n2 = np.random.randint(mn, mx-1) + n2 = n2 if n2 < n1 else n2+1 + return tuple(np.sort([n1, n2])) + + def getTranspAxes(ndim): + axes = list(range(ndim)) + n1, n2 = getTwoInts(0,ndim) + return tuple(axes[:n1]+axes[n2:]+axes[n1:n2]) + + for ndim in range(2, 7): + for dt in ['int8', 'half', 'int32', 'int64']: + for _ in range(5): + dims = list(np.random.randint(5, 20, size=ndim)) + axes = getTranspAxes(ndim) + x = mx.nd.array(np.random.normal(size=dims), dtype=dt) + y = mx.nd.transpose(x, axes=axes) + assert_allclose(np.transpose(x.asnumpy(), axes=axes), y.asnumpy()) + + +@with_seed() +def test_big_transpose(): + n = [1] + d = list(np.random.randint(132, 160, size=1)) + hw = list(np.random.randint(256, 320, size=2)) + c = [10] + dims = n + d + hw + c + axes = (0,4,1,2,3) + x_np = np.random.normal(size=dims).astype('uint8') + x = mx.nd.array(x_np, dtype='uint8') + y = mx.nd.transpose(x, axes=axes) + assert_allclose(np.transpose(x_np, axes=axes), y.asnumpy().astype('uint8')) + axes = (0,2,3,4,1) + z = mx.nd.transpose(y, axes=axes) + assert_allclose(x_np, z.asnumpy().astype('uint8')) + + def test_larger_transpose(): x = mx.nd.random.normal(shape=(50,51)) y = mx.nd.transpose(x)