Comparison ops implemented using mshadow (apache#16414)
* boolean op without tvm working

* Fix astype for boolean arrays

* More tests

* Revert

* Fix

* Fix preprocessor in .cu

* Fix logical_not

* Print compilation flags

* Fix transpose taking negative indices

* Fix transpose negative axes

* Fix transpose

* Fix

* Try to fix USE_TVM_OP not understood in .cu

* Fix squeeze

* Finally

* Fix

* Try to fix invalid ptx

* Implement API to get cuda compute capability

* Fix test_utils.py

* Fix pylint
reminisce authored and aaronmarkham committed Oct 16, 2019
1 parent 84b34d2 commit a3cdbac
Showing 31 changed files with 871 additions and 312 deletions.
6 changes: 5 additions & 1 deletion CMakeLists.txt
@@ -58,6 +58,11 @@ message(STATUS "CMAKE_HOST_SYSTEM_PROCESSOR ${CMAKE_HOST_SYSTEM_PROCESSOR}")
message(STATUS "CMAKE_SYSTEM_PROCESSOR ${CMAKE_SYSTEM_PROCESSOR}")

message(STATUS "CMAKE_SYSTEM_NAME ${CMAKE_SYSTEM_NAME}")

if(USE_TVM_OP)
add_definitions(-DMXNET_USE_TVM_OP=1)
endif()

if(USE_CUDA AND NOT USE_OLDCMAKECUDA)
message(STATUS "CMake version '${CMAKE_VERSION}' using generator '${CMAKE_GENERATOR}'")
if(
@@ -743,7 +748,6 @@ if(USE_DIST_KVSTORE)
endif()

if(USE_TVM_OP)
add_definitions(-DMXNET_USE_TVM_OP=1)
list(APPEND mxnet_LINKER_LIBS ${CMAKE_CURRENT_BINARY_DIR}/3rdparty/tvm/libtvm_runtime.so)
include(cmake/BuildTVM.cmake)
add_subdirectory("3rdparty/tvm")
51 changes: 51 additions & 0 deletions python/mxnet/_numpy_op_doc.py
@@ -653,3 +653,54 @@ def _np_trace(a, offset=0, axis1=0, axis2=1, out=None):
(2, 3)
"""
pass


def _np_squeeze(a, axis=None, out=None):
"""
Remove single-dimensional entries from the shape of an array.

Parameters
----------
a : ndarray
    Input data.
axis : None or int or tuple of ints, optional
    Selects a subset of the single-dimensional entries in the
    shape. If an axis is selected with shape entry greater than
    one, an error is raised.
out : ndarray, optional
    Array into which the output is placed. It must have the same size
    and dtype as the input array.

Returns
-------
squeezed : ndarray
    The input array, but with all or a subset of the
    dimensions of length 1 removed. It always returns a copy of `a`.

Raises
------
MXNetError
    If `axis` is not `None`, and an axis being squeezed is not of length 1

See Also
--------
expand_dims : The inverse operation, adding singleton dimensions
reshape : Insert, remove, and combine dimensions, and resize existing ones

Examples
--------
>>> x = np.array([[[0], [1], [2]]])
>>> x.shape
(1, 3, 1)
>>> np.squeeze(x).shape
(3,)
>>> np.squeeze(x, axis=0).shape
(3, 1)
>>> np.squeeze(x, axis=1).shape
Traceback (most recent call last):
...
mxnet.base.MXNetError: cannot select an axis to squeeze out which has size=3 not equal to one
>>> np.squeeze(x, axis=2).shape
(1, 3)
"""
pass
4 changes: 2 additions & 2 deletions python/mxnet/ndarray/numpy/_op.py
@@ -2110,11 +2110,11 @@ def logical_not(x, out=None, **kwargs):
--------
>>> x= np.array([True, False, 0, 1])
>>> np.logical_not(x)
array([0., 1., 1., 0.])
array([False, True, True, False])
>>> x = np.arange(5)
>>> np.logical_not(x<3)
array([0., 0., 0., 1., 1.])
array([False, False, False, True, True])
"""
return _unary_func_helper(x, _npi.logical_not, _np.logical_not, out=out, **kwargs)

19 changes: 14 additions & 5 deletions python/mxnet/numpy/multiarray.py
@@ -302,7 +302,7 @@ def __getitem__(self, key):
except Exception as err:
raise TypeError('{}'.format(str(err)))
if isinstance(key, _np.ndarray) and key.dtype == _np.bool_:
key = array(key, dtype='bool')
key = array(key, dtype='bool', ctx=self.ctx)
if isinstance(key, ndarray) and key.dtype == _np.bool_: # boolean indexing
key_shape = key.shape
key_ndim = len(key_shape)
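The one-line change above wraps a plain NumPy boolean mask on the context of the array being indexed rather than on the default context. A rough illustrative sketch, assuming a GPU context is available:

import numpy as _np
from mxnet import np, npx

a = np.array([1, 2, 3, 4], ctx=npx.gpu(0))    # array lives on GPU 0
mask = _np.array([True, False, True, False])  # plain NumPy mask, host memory
# The mask is wrapped as a boolean mxnet ndarray on a.ctx before indexing,
# so the lookup no longer mixes devices:
print(a[mask])                                # expected: the elements 1 and 3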
@@ -364,6 +364,8 @@ def __setitem__(self, key, value):
"""
if isinstance(value, NDArray) and not isinstance(value, ndarray):
raise TypeError('Cannot assign mx.nd.NDArray to mxnet.numpy.ndarray')

# handle basic and advanced indexing
if self.ndim == 0:
if not isinstance(key, tuple) or len(key) != 0:
raise IndexError('scalar tensor can only accept `()` as index')
@@ -753,7 +755,7 @@ def detach(self):
check_call(_LIB.MXNDArrayDetach(self.handle, ctypes.byref(hdl)))
return _np_ndarray_cls(hdl)

def astype(self, dtype, *args, **kwargs): # pylint: disable=arguments-differ,unused-argument
def astype(self, dtype, **kwargs): # pylint: disable=arguments-differ,unused-argument
"""
Copy of the array, cast to a specified type.
@@ -1237,7 +1239,14 @@ def tile(self, *args, **kwargs):

def transpose(self, *axes): # pylint: disable=arguments-differ
"""Permute the dimensions of an array."""
return _mx_np_op.transpose(self, axes=axes if len(axes) != 0 else None)
if len(axes) == 0:
axes = None
elif len(axes) == 1:
if isinstance(axes[0], (tuple, list)):
axes = axes[0]
elif axes[0] is None:
axes = None
return _mx_np_op.transpose(self, axes=axes)
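With this change, transpose accepts its axes as separate integers, as a single tuple or list, or as None; negative axes are handled downstream (the commit adds CanonicalizeAxes in src/common/utils.h). A short doctest-style sketch of the accepted forms:

>>> from mxnet import np
>>> a = np.zeros((2, 3, 4))
>>> a.transpose().shape           # no axes: reverse all dimensions
(4, 3, 2)
>>> a.transpose(1, 0, 2).shape    # axes as separate integers
(3, 2, 4)
>>> a.transpose((1, 0, 2)).shape  # axes as a single tuple
(3, 2, 4)
>>> a.transpose(2, 0, -2).shape   # negative axes are canonicalized (-2 -> 1)
(4, 2, 3)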

def flip(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`flip`.
@@ -3401,11 +3410,11 @@ def logical_not(x, out=None, **kwargs):
--------
>>> x= np.array([True, False, 0, 1])
>>> np.logical_not(x)
array([0., 1., 1., 0.])
array([False, True, True, False])
>>> x = np.arange(5)
>>> np.logical_not(x<3)
array([0., 0., 0., 1., 1.])
array([False, False, False, True, True])
"""
return _mx_nd_np.logical_not(x, out=out, **kwargs)

2 changes: 1 addition & 1 deletion python/mxnet/numpy_extension/__init__.py
@@ -25,7 +25,7 @@
from . import _register
from ._op import * # pylint: disable=wildcard-import
from ..context import * # pylint: disable=wildcard-import
from ..util import is_np_shape, is_np_array, set_np, reset_np
from ..util import is_np_shape, is_np_array, set_np, reset_np, get_cuda_compute_capability
from ..ndarray import waitall
from .utils import * # pylint: disable=wildcard-import
from . import random # pylint: disable=wildcard-import
36 changes: 32 additions & 4 deletions python/mxnet/symbol/numpy/_symbol.py
@@ -23,7 +23,7 @@
import numpy as _np
from . import _op as _mx_np_op
from ...base import _LIB, SymbolHandle, numeric_types, mx_uint
from ...util import check_call, set_module
from ...util import check_call, set_module, _sanity_check_params
from ...context import current_context
from ..symbol import Symbol
from .._internal import _set_np_symbol_class
@@ -181,8 +181,29 @@ def T(self):
return self.transpose()
# pylint: enable= invalid-name, undefined-variable

def astype(self, dtype, **kwargs): # pylint: disable=arguments-differ
raise NotImplementedError
def astype(self, dtype, **kwargs): # pylint: disable=arguments-differ,unused-argument
"""
Copy of the array, cast to a specified type.

Parameters
----------
dtype : str or dtype
    Typecode or data-type to which the array is cast.
copy : bool, optional
    Default `True`. By default, astype always returns a newly
    allocated ndarray on the same context. If this is set to
    `False`, and the dtype requested is the same as the ndarray's
    dtype, the ndarray is returned instead of a copy.

Returns
-------
arr_t : ndarray
Unless `copy` is False and the other conditions for returning the input
array are satisfied (see description for `copy` input parameter), `arr_t`
is a new array of the same shape as the input array with `dtype`.
"""
_sanity_check_params('astype', ['order', 'casting', 'subok'], kwargs)
return _npi.cast(self, dtype=dtype)
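Since `_Symbol.astype` now lowers to `_npi.cast` instead of raising NotImplementedError, dtype casts survive hybridization. A hedged sketch assuming the Gluon HybridBlock API (the block below is hypothetical, for illustration only):

from mxnet import np, npx
from mxnet.gluon import HybridBlock
npx.set_np()

class CastBlock(HybridBlock):
    def hybrid_forward(self, F, x):
        # x is an ndarray when run imperatively and a _Symbol once hybridized;
        # astype works in both cases after this change.
        return x.astype('float16')

net = CastBlock()
net.hybridize()
out = net(np.ones((2, 3)))
print(out.dtype)   # expected: float16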

def dot(self, b, out=None):
"""Dot product of two arrays.
@@ -438,7 +459,14 @@ def transpose(self, *axes): # pylint: disable=arguments-differ
"""The arguments are the same as for :py:func:`transpose`, with
this array as data.
"""
return _mx_np_op.transpose(self, axes=axes if len(axes) != 0 else None)
if len(axes) == 0:
axes = None
elif len(axes) == 1:
if isinstance(axes[0], (tuple, list)):
axes = axes[0]
elif axes[0] is None:
axes = None
return _mx_np_op.transpose(self, axes=axes)

def flip(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`flip`.
33 changes: 29 additions & 4 deletions python/mxnet/test_utils.py
@@ -51,6 +51,7 @@
from .symbol.numpy import _Symbol as np_symbol
from .util import use_np # pylint: disable=unused-import
from .runtime import Features
from .numpy_extension import get_cuda_compute_capability


def default_context():
@@ -2235,10 +2236,34 @@ def has_tvm_ops():
"""Returns True if MXNet is compiled with TVM generated operators. If current ctx
is GPU, it only returns True for CUDA compute capability > 52 where FP16 is supported."""
built_with_tvm_op = _features.is_enabled("TVM_OP")
if current_context().device_type == 'gpu':
ctx = current_context()
if ctx.device_type == 'gpu':
try:
import tvm
except ImportError:
cc = get_cuda_compute_capability(ctx)
except: # pylint: disable=bare-except
print('Failed to get CUDA compute capability for context {}. The operators '
'built with USE_TVM_OP=1 will not be run in unit tests.'.format(ctx))
return False
return built_with_tvm_op and (int("".join(tvm.nd.gpu(0).compute_version.split('.'))) >= 53)
print('Cuda arch compute capability: sm_{}'.format(str(cc)))
return built_with_tvm_op and cc >= 53
return built_with_tvm_op


def is_op_runnable():
"""Returns True for all CPU tests. Returns True for GPU tests that are either of the following.
1. Built with USE_TVM_OP=0.
2. Built with USE_TVM_OP=1, but with compute capability >= 53."""
ctx = current_context()
if ctx.device_type == 'gpu':
if not _features.is_enabled("TVM_OP"):
return True
else:
try:
cc = get_cuda_compute_capability(ctx)
except: # pylint: disable=bare-except
print('Failed to get CUDA compute capability for context {}. The operators '
'built with USE_TVM_OP=1 will not be run in unit tests.'.format(ctx))
return False
print('Cuda arch compute capability: sm_{}'.format(str(cc)))
return cc >= 53
return True
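A sketch of how a unit test might gate on these helpers (the test below is hypothetical, for illustration only):

import numpy as _np
from mxnet import np
from mxnet.test_utils import is_op_runnable, use_np

@use_np
def test_np_equal_returns_bool():
    if not is_op_runnable():
        # Built with USE_TVM_OP=1 on a GPU with compute capability < 53; skip.
        return
    a = np.array([1, 2, 3])
    b = np.array([1, 0, 3])
    assert (a == b).dtype == _np.bool_   # comparison ops now yield boolean ndarrays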
61 changes: 61 additions & 0 deletions python/mxnet/util.py
@@ -602,3 +602,64 @@ def set_np(shape=True, array=True):
def reset_np():
"""Deactivate NumPy shape and array semantics at the same time."""
set_np(shape=False, array=False)


_CUDA_SUCCESS = 0


def get_cuda_compute_capability(ctx):
"""Returns the cuda compute capability of the input `ctx`.
Parameters
----------
ctx : Context
GPU context whose corresponding cuda compute capability is to be retrieved.
Returns
-------
cuda_compute_capability : int
CUDA compute capability. For example, it returns 70 for CUDA arch equal to `sm_70`.
References
----------
https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549#file-cuda_check-py
"""
if ctx.device_type != 'gpu':
raise ValueError('Expecting a gpu context to get cuda compute capability, '
'while received ctx {}'.format(str(ctx)))

libnames = ('libcuda.so', 'libcuda.dylib', 'cuda.dll')
for libname in libnames:
try:
cuda = ctypes.CDLL(libname)
except OSError:
continue
else:
break
else:
raise OSError("could not load any of: " + ' '.join(libnames))

# Some constants taken from cuda.h

cc_major = ctypes.c_int()
cc_minor = ctypes.c_int()
device = ctypes.c_int()
error_str = ctypes.c_char_p()

ret = cuda.cuInit(0)
if ret != _CUDA_SUCCESS:
cuda.cuGetErrorString(ret, ctypes.byref(error_str))
raise RuntimeError('cuInit failed with error code {}: {}'
.format(ret, error_str.value.decode()))

ret = cuda.cuDeviceGet(ctypes.byref(device), ctx.device_id)
if ret != _CUDA_SUCCESS:
cuda.cuGetErrorString(ret, ctypes.byref(error_str))
raise RuntimeError('cuDeviceGet failed with error code {}: {}'
.format(ret, error_str.value.decode()))
ret = cuda.cuDeviceComputeCapability(ctypes.byref(cc_major), ctypes.byref(cc_minor), device)
if ret != _CUDA_SUCCESS:
cuda.cuGetErrorString(ret, ctypes.byref(error_str))
raise RuntimeError('cuDeviceComputeCapability failed with error code {}: {}'
.format(ret, error_str.value.decode()))
return cc_major.value * 10 + cc_minor.value
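A minimal usage sketch, assuming a CUDA-capable GPU and driver are present:

import mxnet as mx
from mxnet.util import get_cuda_compute_capability

cc = get_cuda_compute_capability(mx.gpu(0))
print(cc)         # e.g. 70 on an sm_70 (Volta) card
print(cc >= 53)   # True where FP16 kernels can be generated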
15 changes: 15 additions & 0 deletions src/common/utils.h
@@ -823,6 +823,21 @@ static inline std::string GetOutputName(const nnvm::NodeEntry& e) {
return sym.ListOutputNames()[0];
}

inline mxnet::TShape CanonicalizeAxes(const mxnet::TShape& src) {
// convert negative axes to positive values
const int ndim = src.ndim();
mxnet::TShape axes = src;
for (int i = 0; i < ndim; ++i) {
if (axes[i] < 0) {
axes[i] += ndim;
}
CHECK(axes[i] >= 0 && axes[i] < ndim) << "axes[" << i << "]="
<< axes[i] << " exceeds the range ["
<< 0 << ", " << ndim << ")";
}
return axes;
}

} // namespace common
} // namespace mxnet
#endif // MXNET_COMMON_UTILS_H_
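For reference, a hedged pure-Python mirror of the CanonicalizeAxes logic above (the helper name is hypothetical, for illustration only):

def canonicalize_axes(axes):
    """Map negative axes to non-negative ones, e.g. (-1, 0, 1) -> (2, 0, 1)."""
    ndim = len(axes)
    out = []
    for i, axis in enumerate(axes):
        if axis < 0:
            axis += ndim
        if not 0 <= axis < ndim:
            raise ValueError('axes[{}]={} exceeds the range [0, {})'.format(i, axes[i], ndim))
        out.append(axis)
    return tuple(out)

assert canonicalize_axes((-1, 0, 1)) == (2, 0, 1)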
13 changes: 2 additions & 11 deletions src/ndarray/ndarray_function.cc
@@ -36,23 +36,14 @@ template<>
void Copy<cpu, cpu>(const TBlob &from, TBlob *to,
Context from_ctx, Context to_ctx,
RunContext ctx) {
if (from.type_flag_ == mshadow::kBool || to->type_flag_ == mshadow::kBool) {
CHECK_EQ(from.type_flag_, to->type_flag_) << "Only supports copying data between"
" two boolean tensors.";
const index_t size = from.Size();
CHECK_EQ(size, to->Size()) << "copying size mismatch, from: " << size * sizeof(bool)
<< " bytes, to: " << to->Size() * sizeof(bool) << " bytes.";
common::ParallelCopy(to->dptr<bool>(), from.dptr<bool>(), size);
return;
}
MSHADOW_TYPE_SWITCH(to->type_flag_, DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(to->type_flag_, DType, {
if (to->type_flag_ == from.type_flag_) {
const index_t size = static_cast<index_t>(from.Size());
CHECK_EQ(size, to->Size()) << "copying size mismatch, from: " << size * sizeof(DType)
<< " bytes, to: " << to->Size() * sizeof(DType) << " bytes.";
common::ParallelCopy(to->dptr<DType>(), from.dptr<DType>(), size);
} else {
MSHADOW_TYPE_SWITCH(from.type_flag_, SrcDType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(from.type_flag_, SrcDType, {
to->FlatTo1D<cpu, DType>() =
mshadow::expr::tcast<DType>(from.FlatTo1D<cpu, SrcDType>());
})
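Switching these type switches to the boolean-aware MSHADOW_TYPE_SWITCH_WITH_BOOL removes the need for the boolean-only fast path and lets boolean tensors participate in mixed-dtype copies. At the Python level, boolean arrays now convert cleanly to and from numeric dtypes; an illustrative sketch:

from mxnet import np

a = np.array([0, 1, 2])
b = a.astype('bool')          # numeric -> boolean
print(b)                      # expected: array([False,  True,  True])
print(b.astype('float32'))    # expected: array([0., 1., 1.])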
4 changes: 0 additions & 4 deletions src/ndarray/ndarray_function.cu
@@ -76,10 +76,6 @@ void Copy<gpu, gpu>(const TBlob &from, TBlob *to,
from.FlatTo1D<gpu, DType>(s),
s);
} else {
CHECK_NE(from.type_flag_, mshadow::kBool)
<< "Copying boolean ndarray across devices is not supported";
CHECK_NE(to->type_flag_, mshadow::kBool)
<< "Copying boolean ndarray across devices is not supported";
MSHADOW_TYPE_SWITCH(from.type_flag_, SrcDType, {
to->FlatTo1D<gpu, DType>(s) =
mshadow::expr::tcast<DType>(from.FlatTo1D<gpu, SrcDType>(s));
2 changes: 1 addition & 1 deletion src/operator/leaky_relu-inl.h
@@ -134,7 +134,7 @@ class LeakyReLUOp : public Operator {
mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType,
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, DType,
mshadow_op::xelu>, xpu>::
template LaunchEx(s, new_oshape.Size(), req[leakyrelu::kOut], lstride, rstride, oshape,
in_data[leakyrelu::kData].dptr<DType>(), in_data[leakyrelu::kGamma].dptr<DType>(),