diff --git a/CMakeLists.txt b/CMakeLists.txt index 5045bba9d989..70b09917de13 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -58,6 +58,11 @@ message(STATUS "CMAKE_HOST_SYSTEM_PROCESSOR ${CMAKE_HOST_SYSTEM_PROCESSOR}") message(STATUS "CMAKE_SYSTEM_PROCESSOR ${CMAKE_SYSTEM_PROCESSOR}") message(STATUS "CMAKE_SYSTEM_NAME ${CMAKE_SYSTEM_NAME}") + +if(USE_TVM_OP) + add_definitions(-DMXNET_USE_TVM_OP=1) +endif() + if(USE_CUDA AND NOT USE_OLDCMAKECUDA) message(STATUS "CMake version '${CMAKE_VERSION}' using generator '${CMAKE_GENERATOR}'") if( @@ -743,7 +748,6 @@ if(USE_DIST_KVSTORE) endif() if(USE_TVM_OP) - add_definitions(-DMXNET_USE_TVM_OP=1) list(APPEND mxnet_LINKER_LIBS ${CMAKE_CURRENT_BINARY_DIR}/3rdparty/tvm/libtvm_runtime.so) include(cmake/BuildTVM.cmake) add_subdirectory("3rdparty/tvm") diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py index 6e2f5fa15919..8ad5247fa263 100644 --- a/python/mxnet/_numpy_op_doc.py +++ b/python/mxnet/_numpy_op_doc.py @@ -653,3 +653,54 @@ def _np_trace(a, offset=0, axis1=0, axis2=1, out=None): (2, 3) """ pass + + +def _np_squeeze(a, axis=None, out=None): + """ + Remove single-dimensional entries from the shape of an array. + + Parameters + ---------- + a : ndarray + Input data. + axis : None or int or tuple of ints, optional + Selects a subset of the single-dimensional entries in the + shape. If an axis is selected with shape entry greater than + one, an error is raised. + out : ndarray, optional + Array into which the output is placed. It must have the same size + and dtype as the input array. + + Returns + ------- + squeezed : ndarray + The input array, but with all or a subset of the + dimensions of length 1 removed. It always returns a copy of `a`. + + Raises + ------ + MXNetError + If `axis` is not `None`, and an axis being squeezed is not of length 1 + + See Also + -------- + expand_dims : The inverse operation, adding singleton dimensions + reshape : Insert, remove, and combine dimensions, and resize existing ones + + Examples + -------- + >>> x = np.array([[[0], [1], [2]]]) + >>> x.shape + (1, 3, 1) + >>> np.squeeze(x).shape + (3,) + >>> np.squeeze(x, axis=0).shape + (3, 1) + >>> np.squeeze(x, axis=1).shape + Traceback (most recent call last): + ... 
+ mxnet.base.MXNetError: cannot select an axis to squeeze out which has size=3 not equal to one + >>> np.squeeze(x, axis=2).shape + (1, 3) + """ + pass diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index aea8b19c2913..e3cf09703533 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -2110,11 +2110,11 @@ def logical_not(x, out=None, **kwargs): -------- >>> x= np.array([True, False, 0, 1]) >>> np.logical_not(x) - array([0., 1., 1., 0.]) + array([False, True, True, False]) >>> x = np.arange(5) >>> np.logical_not(x<3) - array([0., 0., 0., 1., 1.]) + array([False, False, False, True, True]) """ return _unary_func_helper(x, _npi.logical_not, _np.logical_not, out=out, **kwargs) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index d3ae4d19aca8..5ee52f14bb16 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -302,7 +302,7 @@ def __getitem__(self, key): except Exception as err: raise TypeError('{}'.format(str(err))) if isinstance(key, _np.ndarray) and key.dtype == _np.bool_: - key = array(key, dtype='bool') + key = array(key, dtype='bool', ctx=self.ctx) if isinstance(key, ndarray) and key.dtype == _np.bool_: # boolean indexing key_shape = key.shape key_ndim = len(key_shape) @@ -364,6 +364,8 @@ def __setitem__(self, key, value): """ if isinstance(value, NDArray) and not isinstance(value, ndarray): raise TypeError('Cannot assign mx.nd.NDArray to mxnet.numpy.ndarray') + + # handle basic and advanced indexing if self.ndim == 0: if not isinstance(key, tuple) or len(key) != 0: raise IndexError('scalar tensor can only accept `()` as index') @@ -753,7 +755,7 @@ def detach(self): check_call(_LIB.MXNDArrayDetach(self.handle, ctypes.byref(hdl))) return _np_ndarray_cls(hdl) - def astype(self, dtype, *args, **kwargs): # pylint: disable=arguments-differ,unused-argument + def astype(self, dtype, **kwargs): # pylint: disable=arguments-differ,unused-argument """ Copy of the array, cast to a specified type. @@ -1237,7 +1239,14 @@ def tile(self, *args, **kwargs): def transpose(self, *axes): # pylint: disable=arguments-differ """Permute the dimensions of an array.""" - return _mx_np_op.transpose(self, axes=axes if len(axes) != 0 else None) + if len(axes) == 0: + axes = None + elif len(axes) == 1: + if isinstance(axes[0], (tuple, list)): + axes = axes[0] + elif axes[0] is None: + axes = None + return _mx_np_op.transpose(self, axes=axes) def flip(self, *args, **kwargs): """Convenience fluent method for :py:func:`flip`. @@ -3401,11 +3410,11 @@ def logical_not(x, out=None, **kwargs): -------- >>> x= np.array([True, False, 0, 1]) >>> np.logical_not(x) - array([0., 1., 1., 0.]) + array([False, True, True, False]) >>> x = np.arange(5) >>> np.logical_not(x<3) - array([0., 0., 0., 1., 1.]) + array([False, False, False, True, True]) """ return _mx_nd_np.logical_not(x, out=out, **kwargs) diff --git a/python/mxnet/numpy_extension/__init__.py b/python/mxnet/numpy_extension/__init__.py index 5a19d3dc4e2f..879ab4d56a46 100644 --- a/python/mxnet/numpy_extension/__init__.py +++ b/python/mxnet/numpy_extension/__init__.py @@ -25,7 +25,7 @@ from . 
import _register from ._op import * # pylint: disable=wildcard-import from ..context import * # pylint: disable=wildcard-import -from ..util import is_np_shape, is_np_array, set_np, reset_np +from ..util import is_np_shape, is_np_array, set_np, reset_np, get_cuda_compute_capability from ..ndarray import waitall from .utils import * # pylint: disable=wildcard-import from . import random # pylint: disable=wildcard-import diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 9a909420e934..338c70d1e53d 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -23,7 +23,7 @@ import numpy as _np from . import _op as _mx_np_op from ...base import _LIB, SymbolHandle, numeric_types, mx_uint -from ...util import check_call, set_module +from ...util import check_call, set_module, _sanity_check_params from ...context import current_context from ..symbol import Symbol from .._internal import _set_np_symbol_class @@ -181,8 +181,29 @@ def T(self): return self.transpose() # pylint: enable= invalid-name, undefined-variable - def astype(self, dtype, **kwargs): # pylint: disable=arguments-differ - raise NotImplementedError + def astype(self, dtype, **kwargs): # pylint: disable=arguments-differ,unused-argument + """ + Copy of the array, cast to a specified type. + + Parameters + ---------- + dtype : str or dtype + Typecode or data-type to which the array is cast. + copy : bool, optional + Default `True`. By default, astype always returns a newly + allocated ndarray on the same context. If this is set to + `False`, and the dtype requested is the same as the ndarray's + dtype, the ndarray is returned instead of a copy. + + Returns + ------- + arr_t : ndarray + Unless `copy` is False and the other conditions for returning the input + array are satisfied (see description for `copy` input parameter), `arr_t` + is a new array of the same shape as the input array with `dtype`. + """ + _sanity_check_params('astype', ['order', 'casting', 'subok'], kwargs) + return _npi.cast(self, dtype=dtype) def dot(self, b, out=None): """Dot product of two arrays. @@ -438,7 +459,14 @@ def transpose(self, *axes): # pylint: disable=arguments-differ """The arguments are the same as for :py:func:`transpose`, with this array as data. """ - return _mx_np_op.transpose(self, axes=axes if len(axes) != 0 else None) + if len(axes) == 0: + axes = None + elif len(axes) == 1: + if isinstance(axes[0], (tuple, list)): + axes = axes[0] + elif axes[0] is None: + axes = None + return _mx_np_op.transpose(self, axes=axes) def flip(self, *args, **kwargs): """Convenience fluent method for :py:func:`flip`. diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index c5b3b0d8e4cd..e5f2c6c02089 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -51,6 +51,7 @@ from .symbol.numpy import _Symbol as np_symbol from .util import use_np # pylint: disable=unused-import from .runtime import Features +from .numpy_extension import get_cuda_compute_capability def default_context(): @@ -2235,10 +2236,34 @@ def has_tvm_ops(): """Returns True if MXNet is compiled with TVM generated operators. 
    If current ctx is GPU, it only returns True for CUDA compute capability > 52 where FP16 is supported."""
     built_with_tvm_op = _features.is_enabled("TVM_OP")
-    if current_context().device_type == 'gpu':
+    ctx = current_context()
+    if ctx.device_type == 'gpu':
         try:
-            import tvm
-        except ImportError:
+            cc = get_cuda_compute_capability(ctx)
+        except:  # pylint: disable=bare-except
+            print('Failed to get CUDA compute capability for context {}. The operators '
+                  'built with USE_TVM_OP=1 will not be run in unit tests.'.format(ctx))
             return False
-        return built_with_tvm_op and (int("".join(tvm.nd.gpu(0).compute_version.split('.'))) >= 53)
+        print('CUDA arch compute capability: sm_{}'.format(str(cc)))
+        return built_with_tvm_op and cc >= 53
     return built_with_tvm_op
+
+
+def is_op_runnable():
+    """Returns True for all CPU tests. For GPU tests, returns True in either of these cases:
+    1. MXNet is built with USE_TVM_OP=0.
+    2. MXNet is built with USE_TVM_OP=1 and the GPU's compute capability is >= 53."""
+    ctx = current_context()
+    if ctx.device_type == 'gpu':
+        if not _features.is_enabled("TVM_OP"):
+            return True
+        else:
+            try:
+                cc = get_cuda_compute_capability(ctx)
+            except:  # pylint: disable=bare-except
+                print('Failed to get CUDA compute capability for context {}. The operators '
+                      'built with USE_TVM_OP=1 will not be run in unit tests.'.format(ctx))
+                return False
+            print('CUDA arch compute capability: sm_{}'.format(str(cc)))
+            return cc >= 53
+    return True
diff --git a/python/mxnet/util.py b/python/mxnet/util.py
index d4e95e0c0c9c..1050fb2a481d 100644
--- a/python/mxnet/util.py
+++ b/python/mxnet/util.py
@@ -602,3 +602,64 @@ def set_np(shape=True, array=True):
 def reset_np():
     """Deactivate NumPy shape and array semantics at the same time."""
     set_np(shape=False, array=False)
+
+
+_CUDA_SUCCESS = 0
+
+
+def get_cuda_compute_capability(ctx):
+    """Returns the CUDA compute capability of the input `ctx`.
+
+    Parameters
+    ----------
+    ctx : Context
+        GPU context whose corresponding CUDA compute capability is to be retrieved.
+
+    Returns
+    -------
+    cuda_compute_capability : int
+        CUDA compute capability. For example, it returns 70 for CUDA arch equal to `sm_70`.
+
+    References
+    ----------
+    https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549#file-cuda_check-py
+    """
+    if ctx.device_type != 'gpu':
+        raise ValueError('Expecting a gpu context to get CUDA compute capability, '
+                         'while received ctx {}'.format(str(ctx)))
+
+    libnames = ('libcuda.so', 'libcuda.dylib', 'cuda.dll')
+    for libname in libnames:
+        try:
+            cuda = ctypes.CDLL(libname)
+        except OSError:
+            continue
+        else:
+            break
+    else:
+        raise OSError("could not load any of: " + ' '.join(libnames))
+
+    # Some constants taken from cuda.h
+
+    cc_major = ctypes.c_int()
+    cc_minor = ctypes.c_int()
+    device = ctypes.c_int()
+    error_str = ctypes.c_char_p()
+
+    ret = cuda.cuInit(0)
+    if ret != _CUDA_SUCCESS:
+        cuda.cuGetErrorString(ret, ctypes.byref(error_str))
+        raise RuntimeError('cuInit failed with error code {}: {}'
+                           .format(ret, error_str.value.decode()))
+
+    ret = cuda.cuDeviceGet(ctypes.byref(device), ctx.device_id)
+    if ret != _CUDA_SUCCESS:
+        cuda.cuGetErrorString(ret, ctypes.byref(error_str))
+        raise RuntimeError('cuDeviceGet failed with error code {}: {}'
+                           .format(ret, error_str.value.decode()))
+    ret = cuda.cuDeviceComputeCapability(ctypes.byref(cc_major), ctypes.byref(cc_minor), device)
+    if ret != _CUDA_SUCCESS:
+        cuda.cuGetErrorString(ret, ctypes.byref(error_str))
+        raise RuntimeError('cuDeviceComputeCapability failed with error code {}: {}'
+                           .format(ret, error_str.value.decode()))
+    return cc_major.value * 10 + cc_minor.value
diff --git a/src/common/utils.h b/src/common/utils.h
index 2bd6aac6f5d9..fbecc8b4e955 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -823,6 +823,21 @@ static inline std::string GetOutputName(const nnvm::NodeEntry& e) {
   return sym.ListOutputNames()[0];
 }
 
+inline mxnet::TShape CanonicalizeAxes(const mxnet::TShape& src) {
+  // convert negative axes to positive values
+  const int ndim = src.ndim();
+  mxnet::TShape axes = src;
+  for (int i = 0; i < ndim; ++i) {
+    if (axes[i] < 0) {
+      axes[i] += ndim;
+    }
+    CHECK(axes[i] >= 0 && axes[i] < ndim) << "axes[" << i << "]="
+                                          << axes[i] << " exceeds the range ["
+                                          << 0 << ", " << ndim << ")";
+  }
+  return axes;
+}
+
 }  // namespace common
 }  // namespace mxnet
 #endif  // MXNET_COMMON_UTILS_H_
diff --git a/src/ndarray/ndarray_function.cc b/src/ndarray/ndarray_function.cc
index 653dec84563a..34429446bd62 100644
--- a/src/ndarray/ndarray_function.cc
+++ b/src/ndarray/ndarray_function.cc
@@ -36,23 +36,14 @@ template<>
 void Copy<cpu, cpu>(const TBlob &from, TBlob *to,
                     Context from_ctx, Context to_ctx,
                     RunContext ctx) {
-  if (from.type_flag_ == mshadow::kBool || to->type_flag_ == mshadow::kBool) {
-    CHECK_EQ(from.type_flag_, to->type_flag_) << "Only supports copying data between"
-                                                 " two boolean tensors.";
-    const index_t size = from.Size();
-    CHECK_EQ(size, to->Size()) << "copying size mismatch, from: " << size * sizeof(bool)
-                               << " bytes, to: " << to->Size() * sizeof(bool) << " bytes.";
-    common::ParallelCopy(to->dptr<bool>(), from.dptr<bool>(), size);
-    return;
-  }
-  MSHADOW_TYPE_SWITCH(to->type_flag_, DType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(to->type_flag_, DType, {
     if (to->type_flag_ == from.type_flag_) {
       const index_t size = static_cast<index_t>(from.Size());
       CHECK_EQ(size, to->Size()) << "copying size mismatch, from: " << size * sizeof(DType)
                                  << " bytes, to: " << to->Size() * sizeof(DType) << " bytes.";
       common::ParallelCopy(to->dptr<DType>(), from.dptr<DType>(), size);
     } else {
-      MSHADOW_TYPE_SWITCH(from.type_flag_, SrcDType, {
+      MSHADOW_TYPE_SWITCH_WITH_BOOL(from.type_flag_, SrcDType, {
         to->FlatTo1D<cpu, DType>() =
             mshadow::expr::tcast<DType>(from.FlatTo1D<cpu, SrcDType>());
       })
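[Reviewer note, not part of the patch: a minimal usage sketch of the `get_cuda_compute_capability` helper added above, assuming a GPU build of MXNet with device 0 visible. Per the implementation, it raises ValueError for a non-GPU context, OSError when no CUDA driver library can be loaded, and RuntimeError when a driver call fails:

    import mxnet as mx
    from mxnet.util import get_cuda_compute_capability

    ctx = mx.gpu(0)
    try:
        cc = get_cuda_compute_capability(ctx)  # e.g. 70 on a V100 (sm_70)
        print('compute capability: sm_{}'.format(cc))
    except (ValueError, OSError, RuntimeError) as e:
        print('could not query compute capability: {}'.format(e))

This is the probe that `test_utils.is_op_runnable` relies on to decide whether the TVM-generated kernels, which require sm_53 or higher for FP16 support, can run in unit tests.]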
diff --git a/src/ndarray/ndarray_function.cu b/src/ndarray/ndarray_function.cu
index 6439c417bfe3..2a1461cc8c48 100644
--- a/src/ndarray/ndarray_function.cu
+++ b/src/ndarray/ndarray_function.cu
@@ -76,10 +76,6 @@ void Copy<gpu, gpu>(const TBlob &from, TBlob *to,
                    from.FlatTo1D<gpu, DType>(s),
                    s);
     } else {
-      CHECK_NE(from.type_flag_, mshadow::kBool)
-        << "Copying boolean ndarray across devices is not supported";
-      CHECK_NE(to->type_flag_, mshadow::kBool)
-        << "Copying boolean ndarray across devices is not supported";
       MSHADOW_TYPE_SWITCH(from.type_flag_, SrcDType, {
         to->FlatTo1D<gpu, DType>(s) =
             mshadow::expr::tcast<DType>(from.FlatTo1D<gpu, SrcDType>(s));
diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h
index 2c4127b9a088..d73fa1be54a4 100644
--- a/src/operator/leaky_relu-inl.h
+++ b/src/operator/leaky_relu-inl.h
@@ -134,7 +134,7 @@ class LeakyReLUOp : public Operator {
           mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
           mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
           mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
-          mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, mshadow_op::xelu>, xpu>::
+          mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, mshadow_op::xelu>, xpu>::
           template LaunchEx(s, new_oshape.Size(), req[leakyrelu::kOut], lstride, rstride, oshape,
                             in_data[leakyrelu::kData].dptr<DType>(),
                             in_data[leakyrelu::kGamma].dptr<DType>(),
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index a941f9b25a54..471b6f395a05 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -73,6 +73,14 @@ using std::is_integral;
   } \
   }
 
+#define MXNET_UNARY_LOGIC_OP_NC(name, expr) \
+  struct name : public mxnet_op::tunable { \
+    template<typename DType> \
+    MSHADOW_XINLINE static bool Map(DType a) { \
+      return (expr); \
+    } \
+  }
+
 #define MXNET_BINARY_MATH_OP(name, expr) \
   struct name : public mxnet_op::tunable { \
     template<typename DType> \
@@ -89,6 +97,14 @@ using std::is_integral;
   } \
   }
 
+#define MXNET_BINARY_LOGIC_OP_NC(name, expr) \
+  struct name : public mxnet_op::tunable { \
+    template<typename DType> \
+    MSHADOW_XINLINE static bool Map(DType a, DType b) { \
+      return (expr); \
+    } \
+  }
+
 #define MXNET_SIMPLE_UNARY_MATH_OP(name) MXNET_UNARY_MATH_OP(name, math::name(a))
 
 #define MXNET_SIMPLE_BINARY_MATH_OP(name) MXNET_BINARY_MATH_OP(name, math::name(a, b))
@@ -335,6 +351,8 @@ MXNET_BINARY_MATH_OP(rarctan2_grad, math::id(a) / (math::id(a * a + b * b)));
 
 MXNET_UNARY_MATH_OP_NC(nt, a != DType(0) ? DType(0) : DType(1));
 
+MXNET_UNARY_LOGIC_OP_NC(np_logical_not, !static_cast<bool>(a));
+
 MXNET_BINARY_MATH_OP_NC(ge, a >= b ? DType(1) : DType(0));
 
 MXNET_BINARY_MATH_OP_NC(gt, a > b ? DType(1) : DType(0));
@@ -347,6 +365,18 @@ MXNET_BINARY_MATH_OP_NC(eq, a == b ? DType(1) : DType(0));
 
 MXNET_BINARY_MATH_OP_NC(ne, a != b ? DType(1) : DType(0));
 
+MXNET_BINARY_LOGIC_OP_NC(np_greater_equal, a >= b ? true : false);
+
+MXNET_BINARY_LOGIC_OP_NC(np_greater, a > b ? true : false);
+
+MXNET_BINARY_LOGIC_OP_NC(np_less, a < b ? true : false);
+
+MXNET_BINARY_LOGIC_OP_NC(np_less_equal, a <= b ? true : false);
+
+MXNET_BINARY_LOGIC_OP_NC(np_equal, a == b ? true : false);
+
+MXNET_BINARY_LOGIC_OP_NC(np_not_equal, a != b ? true : false);
+
 MXNET_BINARY_MATH_OP(logical_and, a && b ? DType(1) : DType(0));
 
 MXNET_BINARY_MATH_OP(logical_or, a || b ? DType(1) : DType(0));
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index b46ce8a598d9..950db174595e 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -703,6 +703,26 @@ struct op_with_req {
                                   const DType *input_3) {
     KERNEL_ASSIGN(out[i], req, OP::Map(input_1[i], input_2[i], input_3[i]));
   }
+
+  template<typename DType,
+           typename std::enable_if<!std::is_same<DType, bool>::value, int>::type = 0>
+  MSHADOW_XINLINE static void Map(index_t i, bool *out, const DType *in) {
+    KERNEL_ASSIGN(out[i], req, OP::Map(in[i]));
+  }
+
+  /*! \brief inputs are two tensors with a boolean output tensor */
+  template<typename DType,
+           typename std::enable_if<!std::is_same<DType, bool>::value, int>::type = 0>
+  MSHADOW_XINLINE static void Map(index_t i, bool *out, const DType *lhs, const DType *rhs) {
+    KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
+  }
+
+  /*! \brief input is a tensor and a scalar value with a boolean output tensor */
+  template<typename DType,
+           typename std::enable_if<!std::is_same<DType, bool>::value, int>::type = 0>
+  MSHADOW_XINLINE static void Map(index_t i, bool *out, const DType *in, const DType value) {
+    KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value));
+  }
 };
 
 template<typename OP, typename xpu>
diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h
index cb1cd428bd5e..eda9051fd0a2 100644
--- a/src/operator/nn/dropout-inl.h
+++ b/src/operator/nn/dropout-inl.h
@@ -394,7 +394,7 @@ class DropoutOp {
                 mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
                 mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
                 mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
-                mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, mshadow_op::mul>, xpu>::
+                mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, mshadow_op::mul>, xpu>::
                 template LaunchEx(s, new_oshape.Size(), req[dropout::kOut],
                                   lstride, rstride, oshape,
@@ -463,7 +463,7 @@ class DropoutOp {
                 mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
                 mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
                 mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
-                mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, mshadow_op::mul>, xpu>::
+                mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, mshadow_op::mul>, xpu>::
                 template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
                                   grad.dptr<DType>(), mask.dptr<DType>(), gdata.dptr<DType>());
diff --git a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
new file mode 100644
index 000000000000..7e8951afa1d0
--- /dev/null
+++ b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
@@ -0,0 +1,301 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_elemwise_broadcast_logic_op.cc
+ * \brief CPU Implementation of basic logic functions for elementwise numpy binary
+ *        broadcast operator.
+ */ + +#if MXNET_USE_TVM_OP +#include +#include +#include "../tvmop/op_module.h" +#endif // MXNET_USE_TVM_OP + +#include "../tensor/elemwise_binary_broadcast_op.h" +#include "../tensor/elemwise_binary_scalar_op.h" + +namespace mxnet { +namespace op { + +static constexpr char func_equal_cpu[] = "equal_cpu"; +static constexpr char func_equal_gpu[] = "equal_gpu"; +static constexpr char func_not_equal_cpu[] = "not_equal_cpu"; +static constexpr char func_not_equal_gpu[] = "not_equal_gpu"; +static constexpr char func_greater_cpu[] = "greater_cpu"; +static constexpr char func_greater_gpu[] = "greater_gpu"; +static constexpr char func_less_cpu[] = "less_cpu"; +static constexpr char func_less_gpu[] = "less_gpu"; +static constexpr char func_greater_equal_cpu[] = "greater_equal_cpu"; +static constexpr char func_greater_equal_gpu[] = "greater_equal_gpu"; +static constexpr char func_less_equal_cpu[] = "less_equal_cpu"; +static constexpr char func_less_equal_gpu[] = "less_equal_gpu"; + +bool NumpyBinaryLogicOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + if (in_attrs->at(0) == -1 && in_attrs->at(1) == -1) return false; + TYPE_ASSIGN_CHECK(*in_attrs, 0, in_attrs->at(1)); + TYPE_ASSIGN_CHECK(*in_attrs, 1, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool); + return true; +} + +TBlob PrependAxes(const TBlob& src, const int dst_ndim) { + CHECK_LE(src.shape_.ndim(), dst_ndim); + const int src_ndim = src.shape_.ndim(); + if (src_ndim == dst_ndim) return src; + mxnet::TShape dst_shape(dst_ndim, 1); + for (int i = dst_ndim - src_ndim; i < dst_ndim; ++i) { + dst_shape[i] = src.shape_[i - dst_ndim + src_ndim]; + } + return src.reshape(dst_shape); +} + +struct TVMBinaryBroadcastCompute { + const char* func; + void operator()(const nnvm::NodeAttrs& attrs, + const mxnet::OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { +#if MXNET_USE_TVM_OP + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + if (outputs[0].shape_.Size() == 0U) return; // skip zero-size tensor + + // prepare tblobs and TVMArgs + std::vector tblobs = {inputs[0], inputs[1], outputs[0]}; + std::vector type_codes; + std::vector values; + + const int ondim = outputs[0].shape_.ndim(); + const size_t num_args = inputs.size() + outputs.size(); + type_codes.resize(num_args); + values.resize(num_args); + for (size_t i = 0; i < num_args; ++i) { + tblobs[i] = PrependAxes(tblobs[i], ondim); + type_codes[i] = kArrayHandle; + values[i].v_handle = const_cast(&(tblobs[i].dltensor())); + } + tvm::runtime::TVMArgs tvm_args(&values[0], &type_codes[0], tblobs.size()); + tvm::runtime::TVMOpModule::Get()->CallEx(func, ctx, tblobs, tvm_args); +#else + LOG(FATAL) << "Please add USE_TVM_OP=1 as a compile flag for compiling MXNet source code " + "to enable TVM-generated kernels for operator " << func; +#endif // MXNET_USE_TVM_OP + } +}; + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(name) \ + NNVM_REGISTER_OP(_npi_##name) \ + .set_num_inputs(2) \ + .set_num_outputs(1) \ + .set_attr("FListInputNames", \ + [](const NodeAttrs& attrs) { \ + return std::vector{"lhs", "rhs"}; \ + }) \ + .set_attr("FInferShape", BinaryBroadcastShape) \ + .set_attr("FInferType", NumpyBinaryLogicOpType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs) { \ + return std::vector >{{0, 0}, {1, 0}}; \ + }) \ + .set_attr("FGradient", MakeZeroGradNodes) \ + .add_argument("lhs", 
"NDArray-or-Symbol", "First input to the function") \ + .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") + +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(less_equal); + +#if MXNET_USE_TVM_OP + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_CPU(name) \ + NNVM_REGISTER_OP(_npi_##name) \ + .set_attr("FCompute", TVMBinaryBroadcastCompute{func_##name##_cpu}) + +#if MXNET_USE_CUDA + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(name) \ + NNVM_REGISTER_OP(_npi_##name) \ + .set_attr("FCompute", TVMBinaryBroadcastCompute{func_##name##_gpu}) + +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(less_equal); + +#endif // MXNET_USE_CUDA + +#else + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_CPU(name) \ + NNVM_REGISTER_OP(_npi_##name) \ + .set_attr("FCompute", BinaryBroadcastComputeLogic) + +#endif // MXNET_USE_TVM_OP + +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_CPU(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_CPU(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_CPU(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_CPU(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_CPU(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_CPU(less_equal); + +bool NumpyBinaryScalarLogicOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + if (in_attrs->at(0) == -1) return false; + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool); + return true; +} + +struct TVMBinaryBroadcastScalarCompute { + const char* func; + void operator()(const nnvm::NodeAttrs& attrs, + const mxnet::OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { +#if MXNET_USE_TVM_OP + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + if (outputs[0].shape_.Size() == 0U) return; // skip zero-size tensor + + // prepare tblobs and TVMArgs + std::vector tblobs = {inputs[0], outputs[0]}; + std::vector type_codes; + std::vector values; + + const size_t num_args = 3; // one input tensor, one scalar param, and one output + type_codes.resize(num_args); + values.resize(num_args); + + // input tensor setup + type_codes[0] = kArrayHandle; + values[0].v_handle = const_cast(&(tblobs[0].dltensor())); + + // scalar param + type_codes[1] = kDLFloat; + values[1].v_float64 = nnvm::get(attrs.parsed); + + // output tensor + type_codes[2] = kArrayHandle; + values[2].v_handle = const_cast(&(tblobs[1].dltensor())); + + tvm::runtime::TVMArgs tvm_args(&values[0], &type_codes[0], 3); + tvm::runtime::TVMOpModule::Get()->CallEx(func, ctx, tblobs, tvm_args); +#else + LOG(FATAL) << "Please add USE_TVM_OP=1 as a compile flag for compiling MXNet source code " + "to enable TVM-generated kernels for operator " << func; +#endif // MXNET_USE_TVM_OP + } +}; + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(name) \ + NNVM_REGISTER_OP(_npi_##name##_scalar) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr_parser([](NodeAttrs* attrs) { \ + 
attrs->parsed = std::stod(attrs->dict["scalar"]); \ + }) \ + .set_attr("FListInputNames", \ + [](const NodeAttrs& attrs) { \ + return std::vector{"data"}; \ + }) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyBinaryScalarLogicOpType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs) { \ + return std::vector >{{0, 0}}; \ + }) \ + .set_attr("FGradient", MakeZeroGradNodes) \ + .add_argument("data", "NDArray-or-Symbol", "First input to the function") \ + .add_argument("scalar", "float", "scalar input") + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(less_equal); + +static constexpr char func_equal_scalar_cpu[] = "equal_scalar_cpu"; +static constexpr char func_equal_scalar_gpu[] = "equal_scalar_gpu"; +static constexpr char func_not_equal_scalar_cpu[] = "not_equal_scalar_cpu"; +static constexpr char func_not_equal_scalar_gpu[] = "not_equal_scalar_gpu"; +static constexpr char func_greater_scalar_cpu[] = "greater_scalar_cpu"; +static constexpr char func_greater_scalar_gpu[] = "greater_scalar_gpu"; +static constexpr char func_less_scalar_cpu[] = "less_scalar_cpu"; +static constexpr char func_less_scalar_gpu[] = "less_scalar_gpu"; +static constexpr char func_greater_equal_scalar_cpu[] = "greater_equal_scalar_cpu"; +static constexpr char func_greater_equal_scalar_gpu[] = "greater_equal_scalar_gpu"; +static constexpr char func_less_equal_scalar_cpu[] = "less_equal_scalar_cpu"; +static constexpr char func_less_equal_scalar_gpu[] = "less_equal_scalar_gpu"; + +#if MXNET_USE_TVM_OP + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(name) \ + NNVM_REGISTER_OP(_npi_##name##_scalar) \ + .set_attr("FCompute", TVMBinaryBroadcastScalarCompute{func_##name##_scalar_cpu}) + +#if MXNET_USE_CUDA + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(name) \ + NNVM_REGISTER_OP(_npi_##name##_scalar) \ + .set_attr("FCompute", TVMBinaryBroadcastScalarCompute{func_##name##_scalar_gpu}) + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(less_equal); + +#endif // MXNET_USE_CUDA + +#else + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(name) \ + NNVM_REGISTER_OP(_npi_##name##_scalar) \ + .set_attr("FCompute", BinaryScalarOp::ComputeLogic) + +#endif // MXNET_USE_TVM_OP + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(less_equal); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_logic_op.cu b/src/operator/numpy/np_elemwise_broadcast_logic_op.cu new file mode 100644 index 000000000000..98995a39dcb9 --- /dev/null +++ b/src/operator/numpy/np_elemwise_broadcast_logic_op.cu @@ -0,0 +1,60 @@ +/* 
+ * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_elemwise_broadcast_logic_op.cu + * \brief GPU Implementation of basic functions for elementwise binary + * broadcast logic operator. + */ +#include "../tensor/elemwise_binary_broadcast_op.h" +#include "../tensor/elemwise_binary_scalar_op.h" + +namespace mxnet { +namespace op { + + +#if MXNET_USE_TVM_OP == 0 + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(name) \ + NNVM_REGISTER_OP(_npi_##name) \ + .set_attr("FCompute", BinaryBroadcastComputeLogic) + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(name) \ + NNVM_REGISTER_OP(_npi_##name##_scalar) \ + .set_attr("FCompute", BinaryScalarOp::ComputeLogic) + +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(less_equal); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(not_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(greater); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(less); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(greater_equal); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(less_equal); + +#endif // MXNET_USE_TVM_OP + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index aa81a58c2890..7e07e47b4a8f 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -23,12 +23,6 @@ * \brief CPU Implementation of basic functions for elementwise numpy binary broadcast operator. 
*/ -#if MXNET_USE_TVM_OP -#include -#include -#include "../tvmop/op_module.h" -#endif // MXNET_USE_TVM_OP - #include "../tensor/elemwise_binary_broadcast_op.h" #include "../tensor/elemwise_binary_scalar_op.h" @@ -302,223 +296,6 @@ NNVM_REGISTER_OP(_backward_npi_hypot) .set_attr("FCompute", BinaryBroadcastBackwardUseIn); -static constexpr char func_equal_cpu[] = "equal_cpu"; -static constexpr char func_equal_gpu[] = "equal_gpu"; -static constexpr char func_not_equal_cpu[] = "not_equal_cpu"; -static constexpr char func_not_equal_gpu[] = "not_equal_gpu"; -static constexpr char func_greater_cpu[] = "greater_cpu"; -static constexpr char func_greater_gpu[] = "greater_gpu"; -static constexpr char func_less_cpu[] = "less_cpu"; -static constexpr char func_less_gpu[] = "less_gpu"; -static constexpr char func_greater_equal_cpu[] = "greater_equal_cpu"; -static constexpr char func_greater_equal_gpu[] = "greater_equal_gpu"; -static constexpr char func_less_equal_cpu[] = "less_equal_cpu"; -static constexpr char func_less_equal_gpu[] = "less_equal_gpu"; - -bool NumpyBinaryLogicOpType(const nnvm::NodeAttrs& attrs, - std::vector* in_attrs, - std::vector* out_attrs) { - CHECK_EQ(in_attrs->size(), 2U); - CHECK_EQ(out_attrs->size(), 1U); - if (in_attrs->at(0) == -1 && in_attrs->at(1) == -1) return false; - TYPE_ASSIGN_CHECK(*in_attrs, 0, in_attrs->at(1)); - TYPE_ASSIGN_CHECK(*in_attrs, 1, in_attrs->at(0)); - TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool); - return true; -} - -TBlob PrependAxes(const TBlob& src, const int dst_ndim) { - CHECK_LE(src.shape_.ndim(), dst_ndim); - const int src_ndim = src.shape_.ndim(); - if (src_ndim == dst_ndim) return src; - mxnet::TShape dst_shape(dst_ndim, 1); - for (int i = dst_ndim - src_ndim; i < dst_ndim; ++i) { - dst_shape[i] = src.shape_[i - dst_ndim + src_ndim]; - } - return src.reshape(dst_shape); -} - -struct TVMBinaryBroadcastCompute { - const char* func; - void operator()(const nnvm::NodeAttrs& attrs, - const mxnet::OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { -#if MXNET_USE_TVM_OP - CHECK_EQ(inputs.size(), 2U); - CHECK_EQ(outputs.size(), 1U); - if (outputs[0].shape_.Size() == 0U) return; // skip zero-size tensor - - // prepare tblobs and TVMArgs - std::vector tblobs = {inputs[0], inputs[1], outputs[0]}; - std::vector type_codes; - std::vector values; - - const int ondim = outputs[0].shape_.ndim(); - const size_t num_args = inputs.size() + outputs.size(); - type_codes.resize(num_args); - values.resize(num_args); - for (size_t i = 0; i < num_args; ++i) { - tblobs[i] = PrependAxes(tblobs[i], ondim); - type_codes[i] = kArrayHandle; - values[i].v_handle = const_cast(&(tblobs[i].dltensor())); - } - tvm::runtime::TVMArgs tvm_args(&values[0], &type_codes[0], tblobs.size()); - tvm::runtime::TVMOpModule::Get()->CallEx(func, ctx, tblobs, tvm_args); -#else - LOG(FATAL) << "Please add USE_TVM_OP=1 as a compile flag for compiling MXNet source code " - "to enable TVM-generated kernels for operator " << func; -#endif // MXNET_USE_TVM_OP - } -}; - -#define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(name) \ - NNVM_REGISTER_OP(_npi_##name) \ - .set_num_inputs(2) \ - .set_num_outputs(1) \ - .set_attr("FListInputNames", \ - [](const NodeAttrs& attrs) { \ - return std::vector{"lhs", "rhs"}; \ - }) \ - .set_attr("FInferShape", BinaryBroadcastShape) \ - .set_attr("FInferType", NumpyBinaryLogicOpType) \ - .set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs) { \ - return std::vector >{{0, 0}, {1, 0}}; \ - }) \ - 
.set_attr("FCompute", TVMBinaryBroadcastCompute{func_##name##_cpu}) \ - .set_attr("FGradient", MakeZeroGradNodes) \ - .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \ - .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") - -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(equal); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(not_equal); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(greater); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(less); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(greater_equal); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(less_equal); - -#if MXNET_USE_CUDA -#define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(name) \ - NNVM_REGISTER_OP(_npi_##name) \ - .set_attr("FCompute", TVMBinaryBroadcastCompute{func_##name##_gpu}) - -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(equal); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(not_equal); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(greater); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(less); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(greater_equal); -MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(less_equal); -#endif // MXNET_USE_CUDA - -bool NumpyBinaryScalarLogicOpType(const nnvm::NodeAttrs& attrs, - std::vector* in_attrs, - std::vector* out_attrs) { - CHECK_EQ(in_attrs->size(), 1U); - CHECK_EQ(out_attrs->size(), 1U); - if (in_attrs->at(0) == -1) return false; - TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool); - return true; -} - -struct TVMBinaryBroadcastScalarCompute { - const char* func; - void operator()(const nnvm::NodeAttrs& attrs, - const mxnet::OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { -#if MXNET_USE_TVM_OP - CHECK_EQ(inputs.size(), 1U); - CHECK_EQ(outputs.size(), 1U); - if (outputs[0].shape_.Size() == 0U) return; // skip zero-size tensor - - // prepare tblobs and TVMArgs - std::vector tblobs = {inputs[0], outputs[0]}; - std::vector type_codes; - std::vector values; - - const size_t num_args = 3; // one input tensor, one scalar param, and one output - type_codes.resize(num_args); - values.resize(num_args); - - // input tensor setup - type_codes[0] = kArrayHandle; - values[0].v_handle = const_cast(&(tblobs[0].dltensor())); - - // scalar param - type_codes[1] = kDLFloat; - values[1].v_float64 = nnvm::get(attrs.parsed); - - // output tensor - type_codes[2] = kArrayHandle; - values[2].v_handle = const_cast(&(tblobs[1].dltensor())); - - tvm::runtime::TVMArgs tvm_args(&values[0], &type_codes[0], 3); - tvm::runtime::TVMOpModule::Get()->CallEx(func, ctx, tblobs, tvm_args); -#else - LOG(FATAL) << "Please add USE_TVM_OP=1 as a compile flag for compiling MXNet source code " - "to enable TVM-generated kernels for operator " << func; -#endif // MXNET_USE_TVM_OP - } -}; - -#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(name) \ - NNVM_REGISTER_OP(_npi_##name) \ - .set_num_inputs(1) \ - .set_num_outputs(1) \ - .set_attr_parser([](NodeAttrs* attrs) { \ - attrs->parsed = std::stod(attrs->dict["scalar"]); \ - }) \ - .set_attr("FListInputNames", \ - [](const NodeAttrs& attrs) { \ - return std::vector{"data"}; \ - }) \ - .set_attr("FInferShape", ElemwiseShape<1, 1>) \ - .set_attr("FInferType", NumpyBinaryScalarLogicOpType) \ - .set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs) { \ - return std::vector >{{0, 0}}; \ - }) \ - .set_attr("FCompute", TVMBinaryBroadcastScalarCompute{func_##name##_cpu}) \ - .set_attr("FGradient", MakeZeroGradNodes) \ - .add_argument("data", "NDArray-or-Symbol", "First input to the function") \ - 
.add_argument("scalar", "float", "scalar input") - -static constexpr char func_equal_scalar_cpu[] = "equal_scalar_cpu"; -static constexpr char func_equal_scalar_gpu[] = "equal_scalar_gpu"; -static constexpr char func_not_equal_scalar_cpu[] = "not_equal_scalar_cpu"; -static constexpr char func_not_equal_scalar_gpu[] = "not_equal_scalar_gpu"; -static constexpr char func_greater_scalar_cpu[] = "greater_scalar_cpu"; -static constexpr char func_greater_scalar_gpu[] = "greater_scalar_gpu"; -static constexpr char func_less_scalar_cpu[] = "less_scalar_cpu"; -static constexpr char func_less_scalar_gpu[] = "less_scalar_gpu"; -static constexpr char func_greater_equal_scalar_cpu[] = "greater_equal_scalar_cpu"; -static constexpr char func_greater_equal_scalar_gpu[] = "greater_equal_scalar_gpu"; -static constexpr char func_less_equal_scalar_cpu[] = "less_equal_scalar_cpu"; -static constexpr char func_less_equal_scalar_gpu[] = "less_equal_scalar_gpu"; - -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(equal_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(not_equal_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(greater_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(less_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(greater_equal_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(less_equal_scalar); - -#if MXNET_USE_CUDA -#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(name) \ - NNVM_REGISTER_OP(_npi_##name) \ - .set_attr("FCompute", TVMBinaryBroadcastScalarCompute{func_##name##_gpu}) - -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(equal_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(not_equal_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(greater_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(less_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(greater_equal_scalar); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(less_equal_scalar); -#endif // MXNET_USE_CUDA - MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_ldexp) .set_attr("FCompute", BinaryBroadcastCompute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_ldexp"}); diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cc b/src/operator/numpy/np_elemwise_unary_op_basic.cc index b454d89212e2..c980dcfaab5d 100644 --- a/src/operator/numpy/np_elemwise_unary_op_basic.cc +++ b/src/operator/numpy/np_elemwise_unary_op_basic.cc @@ -65,22 +65,49 @@ NNVM_REGISTER_OP(_np_copy) }) .add_argument("a", "NDArray-or-Symbol", "The input"); -#define MXNET_OPERATOR_REGISTER_NUMPY_UNARY(__name$, __input_name$, __kernel$) \ -NNVM_REGISTER_OP(__name$) \ -.set_num_inputs(1) \ -.set_num_outputs(1) \ -.set_attr("FInferShape", ElemwiseShape<1, 1>) \ -.set_attr("FInferType", ElemwiseType<1, 1>) \ -.set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs){ \ - return std::vector >{{0, 0}}; \ - }) \ -.set_attr("FListInputNames", \ - [](const NodeAttrs& attrs) { \ - return std::vector{__input_name$}; \ - }) \ -.set_attr("FCompute", UnaryOp::Compute) \ -.add_argument(__input_name$, "NDArray-or-Symbol", "The input array.") +#define MXNET_OPERATOR_REGISTER_NUMPY_UNARY(__name$, __input_name$, __kernel$) \ + NNVM_REGISTER_OP(__name$) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", ElemwiseType<1, 1>) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs){ \ + return std::vector >{{0, 0}}; \ + }) \ + .set_attr("FListInputNames", \ + [](const NodeAttrs& attrs) { \ + return 
std::vector{__input_name$}; \ + }) \ + .set_attr("FCompute", UnaryOp::Compute) \ + .add_argument(__input_name$, "NDArray-or-Symbol", "The input array.") + +bool NumpyUnaryLogicOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + if (in_attrs->at(0) == -1) return false; + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool); + return true; +} + +#define MXNET_OPERATOR_REGISTER_NUMPY_UNARY_LOGIC(__name$, __input_name$, __kernel$) \ + NNVM_REGISTER_OP(__name$) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyUnaryLogicOpType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs){ \ + return std::vector >{{0, 0}}; \ + }) \ + .set_attr("FListInputNames", \ + [](const NodeAttrs& attrs) { \ + return std::vector{__input_name$}; \ + }) \ + .set_attr("FCompute", UnaryOp::ComputeLogic) \ + .add_argument(__input_name$, "NDArray-or-Symbol", "The input array.") // negative MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_negative, "x", mshadow_op::negation) @@ -243,11 +270,7 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_expm1, "x", mshadow_op::expm1) // logical_not -MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_logical_not, "x", mshadow_op::nt) -.describe(R"code(Compute the truth value of NOT x element-wise. -Example:: - logical_not([-2., 0., 1.]) = [0., 1., 0.] -)code") +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_LOGIC(_npi_logical_not, "x", mshadow_op::np_logical_not) .set_attr("FGradient", MakeZeroGradNodes); // sin diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cu b/src/operator/numpy/np_elemwise_unary_op_basic.cu index 7f68386560f3..44743ed94be8 100644 --- a/src/operator/numpy/np_elemwise_unary_op_basic.cu +++ b/src/operator/numpy/np_elemwise_unary_op_basic.cu @@ -35,9 +35,9 @@ NNVM_REGISTER_OP(_npx_sigmoid) NNVM_REGISTER_OP(_np_copy) .set_attr("FCompute", UnaryOp::IdentityCompute); -#define MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(__name$, __kernel$) \ -NNVM_REGISTER_OP(__name$) \ -.set_attr("FCompute", UnaryOp::Compute) \ +#define MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(__name$, __kernel$) \ + NNVM_REGISTER_OP(__name$) \ + .set_attr("FCompute", UnaryOp::Compute) MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_negative, mshadow_op::negation); @@ -76,7 +76,8 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_log1p, mshadow_op::log1p); MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_expm1, mshadow_op::expm1); -MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_logical_not, mshadow_op::nt); +NNVM_REGISTER_OP(_npi_logical_not) +.set_attr("FCompute", UnaryOp::ComputeLogic); MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_sin, mshadow_op::sin); diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index 5e25192d9298..0eefeb9cdb5d 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -29,6 +29,7 @@ #include #include "../tensor/matrix_op-inl.h" #include "../nn/concat-inl.h" +#include "../../common/utils.h" namespace mxnet { namespace op { @@ -59,7 +60,8 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs, const NumpyTransposeParam& param = nnvm::get(attrs.parsed); CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace"; if (ndim_is_known(param.axes)) { - TransposeImpl(ctx.run_ctx, inputs[0], outputs[0], param.axes); + mxnet::TShape axes = common::CanonicalizeAxes(param.axes); + TransposeImpl(ctx.run_ctx, inputs[0], outputs[0], axes); } else { 
mxnet::TShape axes(inputs[0].ndim(), -1); for (int i = 0; i < axes.ndim(); ++i) { diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index f54f325bcd17..83f0b1aae9b0 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -40,21 +40,53 @@ bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape& shp = (*in_attrs)[0]; + mxnet::TShape& out_shp = (*out_attrs)[0]; CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions"; - mxnet::TShape ret(shp.ndim(), -1); + + int ndim = -1; + if (ndim_is_known(shp)) { + ndim = shp.ndim(); + } else if (ndim_is_known(out_shp)) { + ndim = out_shp.ndim(); + } + if (ndim < 0) { + return false; + } + if (out_shp.ndim() >= 0 && shp.ndim() >= 0) { + CHECK_EQ(out_shp.ndim(), shp.ndim()); + } + + mxnet::TShape get(ndim, -1); + mxnet::TShape ret(ndim, -1); + if (ndim_is_known(param.axes)) { - CHECK_EQ(shp.ndim(), param.axes.ndim()); - for (int i = 0; i < shp.ndim(); ++i) { - CHECK(param.axes[i] < static_cast(shp.ndim())); - ret[i] = shp[param.axes[i]]; + CHECK_EQ(ndim, param.axes.ndim()); + mxnet::TShape axes = common::CanonicalizeAxes(param.axes); + if (ndim_is_known(shp)) { + for (int i = 0; i < ndim; ++i) { + ret[i] = shp[axes[i]]; + } + } + if (ndim_is_known(out_shp)) { + for (int i = 0; i < ndim; ++i) { + get[axes[i]] = out_shp[i]; + } } } else { - for (int i = 0; i < shp.ndim(); ++i) { - ret[i] = shp[shp.ndim()-1-i]; + if (ndim_is_known(shp)) { + for (int i = 0; i < ndim; ++i) { + ret[i] = shp[ndim - 1 - i]; + } + } + if (ndim_is_known(out_shp)) { + for (int i = 0; i < ndim; ++i) { + get[ndim - 1 - i] = out_shp[i]; + } } } + SHAPE_ASSIGN_CHECK(*in_attrs, 0, get); SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret); - return shape_is_known(ret); + return shape_is_known(*in_attrs) && shape_is_known(*out_attrs); } NNVM_REGISTER_OP(_np_transpose) @@ -69,7 +101,12 @@ NNVM_REGISTER_OP(_np_transpose) if (ndim_is_known(param.axes)) { mxnet::TShape axes = mxnet::TShape(param.axes.ndim(), -1); for (int i = 0; i < axes.ndim(); ++i) { - axes[param.axes[i]] = i; + int axis = param.axes[i]; + if (axis < 0) { + axis += param.axes.ndim(); + } + CHECK(axis >= 0 && axis < param.axes.ndim()); + axes[axis] = i; } std::ostringstream os; os << axes; diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index 49d0e23bdd19..398d5a714bd1 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -209,6 +209,9 @@ struct static_init_var { #define IMPLEMENT_BINARY_WORKLOAD_FWD(__op$) \ MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_BINARY_WORKLOAD_FWD, __op$) +#define IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(__op$) \ + MSHADOW_MACRO_FOREACH_TYPE_WITH_BOOL(_IMPLEMENT_BINARY_WORKLOAD_FWD, __op$) + #define IMPLEMENT_BINARY_WORKLOAD_BWD(__op$) \ MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_BINARY_WORKLOAD_BWD, __op$) @@ -307,6 +310,7 @@ IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::degrees_grad); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::radians); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::radians_grad); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::nt); // NOLINT() +IMPLEMENT_UNARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::np_logical_not); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::nt); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::clip); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::clip); // 
NOLINT() @@ -378,6 +382,12 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ne); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ne); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::eq); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::eq); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::np_equal); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::np_not_equal); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::np_greater); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::np_greater_equal); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::np_less); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::np_less_equal); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_and); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_and); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_or); // NOLINT() diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index 29c476d06d1c..6a612e6f1cd5 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -186,12 +186,12 @@ inline int BinaryBroadcastShapeCompact(const mxnet::TShape& lshape, const mxnet: } namespace mxnet_op { -template +template struct binary_broadcast_kernel { /*! \brief Map function for binary_broadcast_kernel */ MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, const Shape &lstride, const Shape &rstride, - const Shape &oshape, DType *lhs, DType *rhs, + const Shape &oshape, IType *lhs, IType *rhs, DType *out) { Shape coord = unravel(base, oshape); auto lidx = static_cast(dot(coord, lstride)); @@ -209,7 +209,7 @@ struct binary_broadcast_kernel { /*! 
\brief Map function for binary_broadcast_kernel */ MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, const Shape &lstride, const Shape &rstride, - const Shape &oshape, DType lhs, DType *rhs, + const Shape &oshape, IType lhs, IType *rhs, DType *out) { Shape coord = unravel(base, oshape); auto lidx = static_cast(dot(coord, lstride)); @@ -306,7 +306,7 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, mshadow::Shape oshape = new_oshape.get(); mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: + mxnet_op::Kernel, xpu>:: template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, inputs[0].dptr(), inputs[1].dptr(), outputs[0].dptr()); }); @@ -315,6 +315,36 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, } } +template +void BinaryBroadcastComputeLogic(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (outputs[0].shape_.Size() == 0U) return; + mxnet::TShape new_lshape, new_rshape, new_oshape; + int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, + &new_lshape, &new_rshape, &new_oshape); + if (!ndim) { + ElemwiseBinaryOp::ComputeLogic(attrs, ctx, inputs, req, outputs); + } else { + if (req[0] != kNullOp) { + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + inputs[0].dptr(), inputs[1].dptr(), + outputs[0].dptr()); + }); + }); + } + } +} + template void BinaryBroadcastCsrDnsCsrImpl(const OpContext& ctx, const NDArray& csr, @@ -413,11 +443,11 @@ void BinaryBroadcastCsrDnsDnsImpl(const OpContext& ctx, Shape lstride = calc_stride(new_csrshape.get()); Shape rstride = calc_stride(new_dnsshape.get()); if (reverse && std::is_same::value) { - Kernel, xpu>:: + Kernel, xpu>:: template LaunchEx(s, new_oshape.Size(), req, lstride, rstride, oshape, DType(0), dns_data.dptr(), out_data.dptr()); } else { - Kernel, xpu>:: + Kernel, xpu>:: template LaunchEx(s, new_oshape.Size(), req, lstride, rstride, oshape, DType(0), dns_data.dptr(), out_data.dptr()); } diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h index 9c1d8b17fdea..6f444aed21fe 100644 --- a/src/operator/tensor/elemwise_binary_op.h +++ b/src/operator/tensor/elemwise_binary_op.h @@ -497,6 +497,32 @@ class ElemwiseBinaryOp : public OpBase { } } + template + static void ComputeLogic(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + if (req[0] != kNullOp) { + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { + const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) + + DataType::kLanes - 1) / DataType::kLanes; + if (size != 0) { + Kernel, xpu>::Launch(s, size, + outputs[0].dptr(), + inputs[0].dptr(), + inputs[1].dptr()); + } + }); + }); + } + } + 
   template<typename xpu, typename OP>
   static void ComputeWithHalf2(const nnvm::NodeAttrs &attrs,
                                const OpContext &ctx,
diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h
index c78841641214..02b005eed995 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op.h
+++ b/src/operator/tensor/elemwise_binary_scalar_op.h
@@ -244,6 +244,26 @@ class BinaryScalarOp : public UnaryOp {
     });
   }
 
+  template<typename xpu, typename OP>
+  static void ComputeLogic(const nnvm::NodeAttrs &attrs,
+                           const OpContext &ctx,
+                           const std::vector<TBlob> &inputs,
+                           const std::vector<OpReqType> &req,
+                           const std::vector<TBlob> &outputs) {
+    DCHECK_EQ(inputs.size(), 1);
+    DCHECK_EQ(outputs.size(), 1);
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    const double alpha = nnvm::get<double>(attrs.parsed);
+    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
+            s, inputs[0].Size(), outputs[0].dptr<bool>(), inputs[0].dptr<DType>(), DType(alpha));
+      });
+    });
+  }
+
   template<typename xpu, typename OP>
   static void ComputeEx(const nnvm::NodeAttrs &attrs,
                         const OpContext &ctx,
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index 22e7652a4019..b7625fccf258 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -251,6 +251,23 @@ class UnaryOp : public OpBase {
     });
   }
 
+  template<typename xpu, typename OP>
+  static void ComputeLogic(const nnvm::NodeAttrs& attrs,
+                           const OpContext& ctx,
+                           const std::vector<TBlob>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<TBlob>& outputs) {
+    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        if (inputs[0].Size() != 0) {
+          mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
+              s, inputs[0].Size(), outputs[0].dptr<bool>(), inputs[0].dptr<DType>());
+        }
+      });
+    });
+  }
+
   template<typename xpu, typename OP>
   static void ComputeEx(const nnvm::NodeAttrs& attrs,
                         const OpContext& ctx,
@@ -408,7 +425,7 @@ struct CastParam : public dmlc::Parameter<CastParam> {
   int dtype;
   DMLC_DECLARE_PARAMETER(CastParam) {
     DMLC_DECLARE_FIELD(dtype)
-    MXNET_ADD_ALL_TYPES
+    MXNET_ADD_ALL_TYPES_WITH_BOOL
     .describe("Output data type.");
   }
 };
@@ -432,7 +449,7 @@ void CastCompute(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   using namespace mshadow::expr;
   Stream<xpu> *s = ctx.get_stream<xpu>();
-  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DstDType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DstDType, {
     Tensor<xpu, 1, DstDType> out = outputs[0].FlatTo1D<xpu, DstDType>(s);
     MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, SrcDType, {
       Tensor<xpu, 1, SrcDType> data = inputs[0].FlatTo1D<xpu, SrcDType>(s);
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 117cfa96518a..7c35c44305a8 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -650,6 +650,7 @@ Example::
 
 DMLC_REGISTER_PARAMETER(CastParam);
 NNVM_REGISTER_OP(Cast)
 .add_alias("cast")
+.add_alias("_npi_cast")
 .add_alias("_npx_cast")
 .describe(R"code(Casts all elements of the input to a new type.
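Note (not part of the diff): with bool now registered for Cast via MXNET_ADD_ALL_TYPES_WITH_BOOL and the new _npi_cast alias, boolean casts flow through the ordinary astype path. A minimal usage sketch of the intended behavior, assuming a build of this branch with NumPy semantics enabled through npx.set_np(); expected outputs are illustrative:

    from mxnet import np, npx

    npx.set_np()  # enable NumPy-compatible semantics

    x = np.array([0, 1, 2, 0])
    b = x.astype('bool')        # served by Cast, now also reachable as _npi_cast
    print(b.dtype)              # expected: bool
    print(b.astype('float32'))  # bool -> float32 round trip, expected: array([0., 1., 1., 0.])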
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 45df2ed53ded..58d9eaefa04b 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -306,7 +306,6 @@ MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index_t col) {
   }
 }
 
-
 template<typename xpu>
 void TransposeImpl(RunContext ctx,
                    const TBlob& src,
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index 20b964c96a30..53a8076f1303 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -26,7 +26,7 @@ from mxnet.gluon import HybridBlock
 from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, retry, use_np
 from common import with_seed, TemporaryDirectory
-from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, has_tvm_ops, assert_exception
+from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, assert_exception, is_op_runnable
 from mxnet.ndarray.ndarray import py_slice
 from mxnet.base import integer_types
 import scipy.stats as ss
@@ -264,9 +264,10 @@ def test_np_ndarray_binary_element_wise_ops():
         '/': _np.divide,
         'mod': _np.mod,
         'pow': _np.power,
+
     }
 
-    if has_tvm_ops():
+    if is_op_runnable():
        np_op_map.update({
            '==': _np.equal,
            '!=': _np.not_equal,
@@ -468,23 +469,38 @@ def test_np_grad_ndarray_type():
 
 
 @with_seed()
+@use_np
 def test_np_ndarray_astype():
     mx_data = np.array([2, 3, 4, 5], dtype=_np.int32)
     np_data = mx_data.asnumpy()
 
-    def check_astype_equal(dtype, copy, expect_zero_copy=False):
-        mx_ret = mx_data.astype(dtype=dtype, copy=copy)
+    class TestAstype(HybridBlock):
+        def __init__(self, dtype, copy):
+            super(TestAstype, self).__init__()
+            self._dtype = dtype
+            self._copy = copy
+
+        def hybrid_forward(self, F, x):
+            return x.astype(dtype=self._dtype, copy=self._copy)
+
+    def check_astype_equal(dtype, copy, expect_zero_copy=False, hybridize=False):
+        test_astype = TestAstype(dtype, copy)
+        if hybridize:
+            test_astype.hybridize()
+        mx_ret = test_astype(mx_data)
         assert type(mx_ret) is np.ndarray
         np_ret = np_data.astype(dtype=dtype, copy=copy)
         assert mx_ret.dtype == np_ret.dtype
         assert same(mx_ret.asnumpy(), np_ret)
-        if expect_zero_copy:
+        if expect_zero_copy and not hybridize:
             assert id(mx_ret) == id(mx_data)
             assert id(np_ret) == id(np_data)
 
-    for dtype in [_np.int8, _np.uint8, _np.int32, _np.float16, _np.float32, _np.float64]:
+    for dtype in [np.int8, np.uint8, np.int32, np.float16, np.float32, np.float64, np.bool, np.bool_,
+                  'int8', 'uint8', 'int32', 'float16', 'float32', 'float64', 'bool']:
         for copy in [True, False]:
-            check_astype_equal(dtype, copy, copy is False and mx_data.dtype == dtype)
+            for hybridize in [True, False]:
+                check_astype_equal(dtype, copy, copy is False and mx_data.dtype == dtype, hybridize)
 
 
 @with_seed()
@@ -978,7 +994,8 @@ def test_np_multinomial():
 
 
 @with_seed()
-@unittest.skipUnless(has_tvm_ops(), "Comparison ops are implemented using TVM")
+@unittest.skipUnless(is_op_runnable(), "Comparison ops can only run on either CPU instances, or GPU instances with"
+                                       " compute capability >= 53 if MXNet is built with USE_TVM_OP=ON")
 @use_np
 def test_np_ndarray_boolean_indexing():
     def test_single_bool_index():
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 89fe576d0a0e..0848b0861a76 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1163,6 +1163,7 @@ def hybrid_forward(self, F, a):
             axeses.append(tuple(axes))
             random.shuffle(axes)
             axeses.append(tuple(axes))
+            axeses.append([i - len(axes) for i in axes])
         for axes in axeses:
             test_trans = TestTranspose(axes)
             if hybridize:
@@ -1179,10 +1180,15 @@ def hybrid_forward(self, F, a):
                 np_backward = np_transpose_grad(np_out.shape, dtype, axes)
                 assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5, use_broadcast=False)
 
-        mx_out = np.transpose(x, axes)
-        np_out = _np.transpose(x.asnumpy(), axes)
+        mx_out = x.transpose(axes)
+        np_out = x.asnumpy().transpose(axes)
         assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
 
+        if isinstance(axes, (list, tuple)):
+            mx_out = x.transpose(*axes)
+            np_out = x.asnumpy().transpose(*axes)
+            assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
+
 
 @with_seed()
 @use_np
@@ -1342,6 +1348,8 @@ def hybrid_forward(self, F, a, *args, **kwargs):
             y = mx_func(mx_test_data)
             assert y.shape == np_out.shape
             assert_almost_equal(y.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+            if np_out.dtype == np.bool_:
+                assert y.dtype == np.bool_
 
             if ref_grad:
                 y.backward()
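Note (not part of the diff): the test changes above pin down the user-visible surface of this PR: comparisons now yield bool arrays, those arrays drive boolean indexing, and transpose accepts negative axes both as a sequence and unpacked. A quick sketch of that surface, assuming a build of this branch; the comparison kernels and boolean indexing are gated on is_op_runnable(), i.e. CPU, or GPU with compute capability >= 53 when built with USE_TVM_OP=ON:

    from mxnet import np, npx

    npx.set_np()

    x = np.arange(5)
    m = x < 3                    # scalar comparison (np_less); output dtype is now bool
    print(m.dtype)               # expected: bool
    print(np.logical_not(m))     # per updated docstrings: array([False, False, False, True, True])
    print(x[m])                  # boolean-mask indexing: array([0., 1., 2.])

    y = np.zeros((2, 3, 4))
    print(y.transpose(-1, 0, -2).shape)  # negative axes now accepted: (4, 2, 3)
    print(y.transpose([2, 0, 1]).shape)  # a single list/tuple argument also works: (4, 2, 3)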