Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge remote-tracking branch 'origin/master' into mkldnn-v1.0
Browse files Browse the repository at this point in the history
Conflicts:
	src/operator/nn/mkldnn/mkldnn_base-inl.h
	src/operator/nn/mkldnn/mkldnn_flatten-inl.h
	src/operator/nn/mkldnn/mkldnn_flatten.cc
	src/operator/nn/mkldnn/mkldnn_ops-inl.h
	src/operator/nn/mkldnn/mkldnn_reshape-inl.h
	src/operator/nn/mkldnn/mkldnn_reshape.cc
	src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc
	src/operator/tensor/matrix_op.cc
  • Loading branch information
TaoLv committed Oct 15, 2019
2 parents 9f77575 + 6d6e46b commit 43e35a9
Show file tree
Hide file tree
Showing 51 changed files with 1,667 additions and 976 deletions.
6 changes: 5 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ message(STATUS "CMAKE_HOST_SYSTEM_PROCESSOR ${CMAKE_HOST_SYSTEM_PROCESSOR}")
message(STATUS "CMAKE_SYSTEM_PROCESSOR ${CMAKE_SYSTEM_PROCESSOR}")

message(STATUS "CMAKE_SYSTEM_NAME ${CMAKE_SYSTEM_NAME}")

if(USE_TVM_OP)
add_definitions(-DMXNET_USE_TVM_OP=1)
endif()

if(USE_CUDA AND NOT USE_OLDCMAKECUDA)
message(STATUS "CMake version '${CMAKE_VERSION}' using generator '${CMAKE_GENERATOR}'")
if(
Expand Down Expand Up @@ -739,7 +744,6 @@ if(USE_DIST_KVSTORE)
endif()

if(USE_TVM_OP)
add_definitions(-DMXNET_USE_TVM_OP=1)
list(APPEND mxnet_LINKER_LIBS ${CMAKE_CURRENT_BINARY_DIR}/3rdparty/tvm/libtvm_runtime.so)
include(cmake/BuildTVM.cmake)
add_subdirectory("3rdparty/tvm")
Expand Down
2 changes: 1 addition & 1 deletion ci/docker_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
DOCKERHUB_LOGIN_NUM_RETRIES = 5
DOCKERHUB_RETRY_SECONDS = 5
DOCKER_CACHE_NUM_RETRIES = 3
DOCKER_CACHE_TIMEOUT_MINS = 15
DOCKER_CACHE_TIMEOUT_MINS = 45
PARALLEL_BUILDS = 10


Expand Down
2 changes: 1 addition & 1 deletion docs/static_site/src/pages/api/faq/add_op_in_backend.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
layout: page_category
title: Exception Handling in MXNet
title: A Beginner's Guide to Implementing Operators in MXNet Backend
category: faq
faq_c: Extend and Contribute to MXNet
question: How do I implement operators in MXNet backend?
Expand Down
51 changes: 51 additions & 0 deletions python/mxnet/_numpy_op_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,3 +653,54 @@ def _np_trace(a, offset=0, axis1=0, axis2=1, out=None):
(2, 3)
"""
pass


def _np_squeeze(a, axis=None, out=None):
"""
Remove single-dimensional entries from the shape of an array.
Parameters
----------
a : ndarray
Input data.
axis : None or int or tuple of ints, optional
Selects a subset of the single-dimensional entries in the
shape. If an axis is selected with shape entry greater than
one, an error is raised.
out : ndarray, optional
Array into which the output is placed. It must have the same size
and dtype as the input array.
Returns
-------
squeezed : ndarray
The input array, but with all or a subset of the
dimensions of length 1 removed. It always returns a copy of `a`.
Raises
------
MXNetError
If `axis` is not `None`, and an axis being squeezed is not of length 1
See Also
--------
expand_dims : The inverse operation, adding singleton dimensions
reshape : Insert, remove, and combine dimensions, and resize existing ones
Examples
--------
>>> x = np.array([[[0], [1], [2]]])
>>> x.shape
(1, 3, 1)
>>> np.squeeze(x).shape
(3,)
>>> np.squeeze(x, axis=0).shape
(3, 1)
>>> np.squeeze(x, axis=1).shape
Traceback (most recent call last):
...
mxnet.base.MXNetError: cannot select an axis to squeeze out which has size=3 not equal to one
>>> np.squeeze(x, axis=2).shape
(1, 3)
"""
pass
53 changes: 50 additions & 3 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs',
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
Expand Down Expand Up @@ -741,6 +741,53 @@ def tensordot(a, b, axes=2):
return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)


@set_module('mxnet.ndarray.numpy')
def histogram(a, bins=10, range=None, normed=None, weights=None, density=None): # pylint: disable=too-many-arguments
"""
Compute the histogram of a set of data.
Parameters
----------
a : ndarray
Input data. The histogram is computed over the flattened array.
bins : int or NDArray
If `bins` is an int, it defines the number of equal-width
bins in the given range (10, by default). If `bins` is a
sequence, it defines a monotonically increasing array of bin edges,
including the rightmost edge, allowing for non-uniform bin widths.
.. versionadded:: 1.11.0
If `bins` is a string, it defines the method used to calculate the
optimal bin width, as defined by `histogram_bin_edges`.
range : (float, float)
The lower and upper range of the bins. Required when `bins` is an integer.
Values outside the range are ignored. The first element of the range must
be less than or equal to the second.
normed : bool, optional
Not supported yet, coming soon.
weights : array_like, optional
Not supported yet, coming soon.
density : bool, optional
Not supported yet, coming soon.
"""
if normed is True:
raise NotImplementedError("normed is not supported yet...")
if weights is not None:
raise NotImplementedError("weights is not supported yet...")
if density is True:
raise NotImplementedError("density is not supported yet...")
if isinstance(bins, numeric_types):
if range is None:
raise NotImplementedError("automatic range is not supported yet...")
return _npi.histogram(a, bin_cnt=bins, range=range)
if isinstance(bins, (list, tuple)):
raise NotImplementedError("array_like bins is not supported yet...")
if isinstance(bins, str):
raise NotImplementedError("string bins is not supported yet...")
if isinstance(bins, NDArray):
return _npi.histogram(a, bins=bins)
raise ValueError("np.histogram fails with", locals())


@set_module('mxnet.ndarray.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments
r"""
Expand Down Expand Up @@ -2063,11 +2110,11 @@ def logical_not(x, out=None, **kwargs):
--------
>>> x= np.array([True, False, 0, 1])
>>> np.logical_not(x)
array([0., 1., 1., 0.])
array([False, True, True, False])
>>> x = np.arange(5)
>>> np.logical_not(x<3)
array([0., 0., 0., 1., 1.])
array([False, False, False, True, True])
"""
return _unary_func_helper(x, _npi.logical_not, _np.logical_not, out=out, **kwargs)

Expand Down
53 changes: 47 additions & 6 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@
'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'tensordot', 'histogram', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']


# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
_NDARRAY_BASIC_INDEXING = 0
Expand Down Expand Up @@ -301,7 +302,7 @@ def __getitem__(self, key):
except Exception as err:
raise TypeError('{}'.format(str(err)))
if isinstance(key, _np.ndarray) and key.dtype == _np.bool_:
key = array(key, dtype='bool')
key = array(key, dtype='bool', ctx=self.ctx)
if isinstance(key, ndarray) and key.dtype == _np.bool_: # boolean indexing
key_shape = key.shape
key_ndim = len(key_shape)
Expand Down Expand Up @@ -363,6 +364,8 @@ def __setitem__(self, key, value):
"""
if isinstance(value, NDArray) and not isinstance(value, ndarray):
raise TypeError('Cannot assign mx.nd.NDArray to mxnet.numpy.ndarray')

# handle basic and advanced indexing
if self.ndim == 0:
if not isinstance(key, tuple) or len(key) != 0:
raise IndexError('scalar tensor can only accept `()` as index')
Expand Down Expand Up @@ -752,7 +755,7 @@ def detach(self):
check_call(_LIB.MXNDArrayDetach(self.handle, ctypes.byref(hdl)))
return _np_ndarray_cls(hdl)

def astype(self, dtype, *args, **kwargs): # pylint: disable=arguments-differ,unused-argument
def astype(self, dtype, **kwargs): # pylint: disable=arguments-differ,unused-argument
"""
Copy of the array, cast to a specified type.
Expand Down Expand Up @@ -1236,7 +1239,14 @@ def tile(self, *args, **kwargs):

def transpose(self, *axes): # pylint: disable=arguments-differ
"""Permute the dimensions of an array."""
return _mx_np_op.transpose(self, axes=axes if len(axes) != 0 else None)
if len(axes) == 0:
axes = None
elif len(axes) == 1:
if isinstance(axes[0], (tuple, list)):
axes = axes[0]
elif axes[0] is None:
axes = None
return _mx_np_op.transpose(self, axes=axes)

def flip(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`flip`.
Expand Down Expand Up @@ -3400,11 +3410,11 @@ def logical_not(x, out=None, **kwargs):
--------
>>> x= np.array([True, False, 0, 1])
>>> np.logical_not(x)
array([0., 1., 1., 0.])
array([False, True, True, False])
>>> x = np.arange(5)
>>> np.logical_not(x<3)
array([0., 0., 0., 1., 1.])
array([False, False, False, True, True])
"""
return _mx_nd_np.logical_not(x, out=out, **kwargs)

Expand Down Expand Up @@ -3604,6 +3614,37 @@ def tensordot(a, b, axes=2):
return _mx_nd_np.tensordot(a, b, axes)


@set_module('mxnet.numpy')
def histogram(a, bins=10, range=None, normed=None, weights=None, density=None): # pylint-disable=too-many-arguments
"""
Compute the histogram of a set of data.
Parameters
----------
a : ndarray
Input data. The histogram is computed over the flattened array.
bins : int or NDArray
If `bins` is an int, it defines the number of equal-width
bins in the given range (10, by default). If `bins` is a
sequence, it defines a monotonically increasing array of bin edges,
including the rightmost edge, allowing for non-uniform bin widths.
.. versionadded:: 1.11.0
If `bins` is a string, it defines the method used to calculate the
optimal bin width, as defined by `histogram_bin_edges`.
range : (float, float)
The lower and upper range of the bins. Required when `bins` is an integer.
Values outside the range are ignored. The first element of the range must
be less than or equal to the second.
normed : bool, optional
Not supported yet, coming soon.
weights : array_like, optional
Not supported yet, coming soon.
density : bool, optional
Not supported yet, coming soon.
"""
return _mx_nd_np.histogram(a, bins=bins, range=range, normed=normed, weights=weights, density=density)


@set_module('mxnet.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments
r"""
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/numpy_extension/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from . import _register
from ._op import * # pylint: disable=wildcard-import
from ..context import * # pylint: disable=wildcard-import
from ..util import is_np_shape, is_np_array, set_np, reset_np
from ..util import is_np_shape, is_np_array, set_np, reset_np, get_cuda_compute_capability
from ..ndarray import waitall
from .utils import * # pylint: disable=wildcard-import
from . import random # pylint: disable=wildcard-import
Expand Down
Loading

0 comments on commit 43e35a9

Please sign in to comment.