diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index a3b4a2730378..e8332f1a83ef 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -35,7 +35,7 @@ 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', - 'unique'] + 'unique', 'lcm', 'tril', 'identity', 'take'] @set_module('mxnet.ndarray.numpy') @@ -50,7 +50,7 @@ def zeros(shape, dtype=_np.float32, order='C', ctx=None): The shape of the empty array. dtype : str or numpy.dtype, optional An optional value type. Default is `numpy.float32`. Note that this - behavior is different from NumPy's `ones` function where `float64` + behavior is different from NumPy's `zeros` function where `float64` is the default value, because `float32` is considered as the default data type in deep learning. order : {'C'}, optional, default: 'C' @@ -96,7 +96,7 @@ def ones(shape, dtype=_np.float32, order='C', ctx=None): Returns ------- out : ndarray - Array of zeros with the given shape, dtype, and ctx. + Array of ones with the given shape, dtype, and ctx. """ if order != 'C': raise NotImplementedError @@ -213,6 +213,134 @@ def arange(start, stop=None, step=1, dtype=None, ctx=None): return _npi.arange(start=start, stop=stop, step=step, dtype=dtype, ctx=ctx) +@set_module('mxnet.ndarray.numpy') +def identity(n, dtype=None, ctx=None): + """ + Return the identity array. + + The identity array is a square array with ones on + the main diagonal. + + Parameters + ---------- + n : int + Number of rows (and columns) in `n` x `n` output. + dtype : data-type, optional + Data-type of the output. Defaults to ``numpy.float32``. + ctx : Context, optional + An optional device context (default is the current default context). + + Returns + ------- + out : ndarray + `n` x `n` array with its main diagonal set to one, + and all other elements 0. + + Examples + -------- + >>> np.identity(3) + >>> np.identity(3) + array([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]]) + """ + if not isinstance(n, int): + raise TypeError("Input 'n' should be an integer") + if n < 0: + raise ValueError("Input 'n' cannot be negative") + if ctx is None: + ctx = current_context() + dtype = _np.float32 if dtype is None else dtype + return _npi.identity(shape=(n, n), ctx=ctx, dtype=dtype) + + +# pylint: disable=redefined-outer-name +@set_module('mxnet.ndarray.numpy') +def take(a, indices, axis=None, mode='raise', out=None): + r""" + Take elements from an array along an axis. + + When axis is not None, this function does the same thing as "fancy" + indexing (indexing arrays using arrays); however, it can be easier to use + if you need elements along a given axis. A call such as + ``np.take(arr, indices, axis=3)`` is equivalent to + ``arr[:,:,:,indices,...]``. + + Explained without fancy indexing, this is equivalent to the following use + of `ndindex`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of + indices:: + + Ni, Nk = a.shape[:axis], a.shape[axis+1:] + Nj = indices.shape + for ii in ndindex(Ni): + for jj in ndindex(Nj): + for kk in ndindex(Nk): + out[ii + jj + kk] = a[ii + (indices[jj],) + kk] + + Parameters + ---------- + a : ndarray + The source array. + indices : ndarray + The indices of the values to extract. Also allow scalars for indices. + axis : int, optional + The axis over which to select values. By default, the flattened + input array is used. + out : ndarray, optional + If provided, the result will be placed in this array. It should + be of the appropriate shape and dtype. + mode : {'clip', 'wrap'}, optional + Specifies how out-of-bounds indices will behave. + + * 'clip' -- clip to the range (default) + * 'wrap' -- wrap around + + 'clip' mode means that all indices that are too large are replaced + by the index that addresses the last element along that axis. Note + that this disables indexing with negative numbers. + + Returns + ------- + out : ndarray + The returned array has the same type as `a`. + + Notes + ----- + + This function differs from the original `numpy.take + `_ in + the following way(s): + + - Only ndarray or scalar ndarray is accepted as valid input. + + Examples + -------- + >>> a = np.array([4, 3, 5, 7, 6, 8]) + >>> indices = np.array([0, 1, 4]) + >>> np.take(a, indices) + array([4., 3., 6.]) + + In this example for `a` is an ndarray, "fancy" indexing can be used. + + >>> a[indices] + array([4., 3., 6.]) + + If `indices` is not one dimensional, the output also has these dimensions. + + >>> np.take(a, np.array([[0, 1], [2, 3]])) + array([[4., 3.], + [5., 7.]]) + """ + if mode not in ('wrap', 'clip', 'raise'): + raise NotImplementedError( + "function take does not support mode '{}'".format(mode)) + if axis: + return _npi.take(a, indices, axis, mode, out) + else: + return _npi.take(_npi.reshape(a, -1), indices, 0, mode, out) +# pylint: enable=redefined-outer-name + + #pylint: disable= too-many-arguments, no-member, protected-access def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, out=None): """ Helper function for element-wise operation. @@ -732,6 +860,79 @@ def expand_dims(a, axis): return _npi.expand_dims(a, axis) +@set_module('mxnet.ndarray.numpy') +def lcm(x1, x2, out=None): + """ + Returns the lowest common multiple of ``|x1|`` and ``|x2|`` + + Parameters + ---------- + x1, x2 : ndarrays or scalar values + The arrays for computing lowest common multiple. If x1.shape != x2.shape, + they must be broadcastable to a common shape (which may be the shape of + one or the other). + + out : ndarray or None, optional + A location into which the result is stored. If provided, it must have a shape + that the inputs broadcast to. If not provided or None, a freshly-allocated array + is returned. + + Returns + ------- + y : ndarray or scalar + The lowest common multiple of the absolute value of the inputs + This is a scalar if both `x1` and `x2` are scalars. + + See Also + -------- + gcd : The greatest common divisor + + Examples + -------- + >>> np.lcm(12, 20) + 60 + >>> np.lcm(np.arange(6, dtype=int), 20) + array([ 0, 20, 20, 60, 20, 20], dtype=int64) + """ + return _ufunc_helper(x1, x2, _npi.lcm, _np.lcm, _npi.lcm_scalar, None, out) + + +@set_module('mxnet.ndarray.numpy') +def tril(m, k=0): + r""" + Lower triangle of an array. + + Return a copy of an array with elements above the `k`-th diagonal zeroed. + + Parameters + ---------- + m : ndarray, shape (M, N) + Input array. + k : int, optional + Diagonal above which to zero elements. `k = 0` (the default) is the + main diagonal, `k < 0` is below it and `k > 0` is above. + + Returns + ------- + tril : ndarray, shape (M, N) + Lower triangle of `m`, of same shape and data-type as `m`. + + See Also + -------- + triu : same thing, only for the upper triangle + + Examples + -------- + >>> a = np.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]]) + >>> np.tril(a, -1) + array([[ 0., 0., 0.], + [ 4., 0., 0.], + [ 7., 8., 0.], + [10., 11., 12.]]) + """ + return _npi.tril(m, k) + + def _unary_func_helper(x, fn_array, fn_scalar, out=None, **kwargs): """Helper function for unary operators. diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 4972bdae7df6..7ba0f0d7d813 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -54,7 +54,7 @@ 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', - 'rad2deg', 'deg2rad', 'unique'] + 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -1104,13 +1104,13 @@ def slice_assign(self, rhs, begin, end, step): """ return _npi.slice_assign(self, rhs, begin=begin, end=end, step=step, out=self) - def take(self, *args, **kwargs): + def take(self, indices, axis=None, mode='raise'): # pylint: disable=arguments-differ, redefined-outer-name """Convenience fluent method for :py:func:`take`. The arguments are the same as for :py:func:`take`, with this array as data. """ - raise NotImplementedError + take(self, indices, axis, mode=mode) def one_hot(self, *args, **kwargs): """Convenience fluent method for :py:func:`one_hot`. @@ -1754,7 +1754,7 @@ def zeros(shape, dtype=_np.float32, order='C', ctx=None): The shape of the empty array. dtype : str or numpy.dtype, optional An optional value type (default is `numpy.float32`). Note that this - behavior is different from NumPy's `ones` function where `float64` + behavior is different from NumPy's `zeros` function where `float64` is the default value, because `float32` is considered as the default data type in deep learning. order : {'C'}, optional, default: 'C' @@ -1773,7 +1773,7 @@ def zeros(shape, dtype=_np.float32, order='C', ctx=None): @set_module('mxnet.numpy') def ones(shape, dtype=_np.float32, order='C', ctx=None): - """Return a new array of given shape and type, filled with zeros. + """Return a new array of given shape and type, filled with ones. This function currently only supports storing multi-dimensional data in row-major (C-style). @@ -1795,7 +1795,7 @@ def ones(shape, dtype=_np.float32, order='C', ctx=None): Returns ------- out : ndarray - Array of zeros with the given shape, dtype, and ctx. + Array of ones with the given shape, dtype, and ctx. """ return _mx_nd_np.ones(shape, dtype, order, ctx) @@ -1856,6 +1856,121 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin return _mx_nd_np.full(shape, fill_value, order=order, ctx=ctx, dtype=dtype, out=out) +@set_module('mxnet.numpy') +def identity(n, dtype=None, ctx=None): + """ + Return the identity array. + + The identity array is a square array with ones on + the main diagonal. + + Parameters + ---------- + n : int + Number of rows (and columns) in `n` x `n` output. + dtype : data-type, optional + Data-type of the output. Defaults to ``numpy.float32``. + ctx : Context, optional + An optional device context (default is the current default context). + + Returns + ------- + out : ndarray + `n` x `n` array with its main diagonal set to one, + and all other elements 0. + + Examples + -------- + >>> np.identity(3) + >>> np.identity(3) + array([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]]) + """ + return _mx_nd_np.identity(n, dtype, ctx) + + +# pylint: disable=redefined-outer-name +@set_module('mxnet.numpy') +def take(a, indices, axis=None, mode='raise', out=None): + r""" + Take elements from an array along an axis. + + When axis is not None, this function does the same thing as "fancy" + indexing (indexing arrays using arrays); however, it can be easier to use + if you need elements along a given axis. A call such as + ``np.take(arr, indices, axis=3)`` is equivalent to + ``arr[:,:,:,indices,...]``. + + Explained without fancy indexing, this is equivalent to the following use + of `ndindex`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of + indices:: + + Ni, Nk = a.shape[:axis], a.shape[axis+1:] + Nj = indices.shape + for ii in ndindex(Ni): + for jj in ndindex(Nj): + for kk in ndindex(Nk): + out[ii + jj + kk] = a[ii + (indices[jj],) + kk] + + Parameters + ---------- + a : ndarray + The source array. + indices : ndarray + The indices of the values to extract. Also allow scalars for indices. + axis : int, optional + The axis over which to select values. By default, the flattened + input array is used. + out : ndarray, optional + If provided, the result will be placed in this array. It should + be of the appropriate shape and dtype. + mode : {'clip', 'wrap'}, optional + Specifies how out-of-bounds indices will behave. + + * 'clip' -- clip to the range (default) + * 'wrap' -- wrap around + + 'clip' mode means that all indices that are too large are replaced + by the index that addresses the last element along that axis. Note + that this disables indexing with negative numbers. + + Returns + ------- + out : ndarray + The returned array has the same type as `a`. + + Notes + ----- + + This function differs from the original `numpy.take + `_ in + the following way(s): + + - Only ndarray or scalar ndarray is accepted as valid input. + + Examples + -------- + >>> a = np.array([4, 3, 5, 7, 6, 8]) + >>> indices = np.array([0, 1, 4]) + >>> np.take(a, indices) + array([4., 3., 6.]) + + In this example for `a` is an ndarray, "fancy" indexing can be used. + + >>> a[indices] + array([4., 3., 6.]) + + If `indices` is not one dimensional, the output also has these dimensions. + + >>> np.take(a, np.array([[0, 1], [2, 3]])) + array([[4., 3.], + [5., 7.]]) + """ + return _mx_nd_np.take(a, indices, axis, mode, out) +# pylint: enable=redefined-outer-name + + @set_module('mxnet.numpy') def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None): """ @@ -2131,6 +2246,43 @@ def power(x1, x2, out=None): return _mx_nd_np.power(x1, x2, out=out) +@set_module('mxnet.numpy') +def lcm(x1, x2, out=None): + """ + Returns the lowest common multiple of ``|x1|`` and ``|x2|`` + + Parameters + ---------- + x1, x2 : ndarrays or scalar values + The arrays for computing lowest common multiple. If x1.shape != x2.shape, + they must be broadcastable to a common shape (which may be the shape of + one or the other). + + out : ndarray or None, optional + A location into which the result is stored. If provided, it must have a shape + that the inputs broadcast to. If not provided or None, a freshly-allocated array + is returned. + + Returns + ------- + y : ndarray or scalar + The lowest common multiple of the absolute value of the inputs + This is a scalar if both `x1` and `x2` are scalars. + + See Also + -------- + gcd : The greatest common divisor + + Examples + -------- + >>> np.lcm(12, 20) + 60 + >>> np.lcm(np.arange(6, dtype=int), 20) + array([ 0, 20, 20, 60, 20, 20], dtype=int64) + """ + return _mx_nd_np.lcm(x1, x2, out=out) + + @set_module('mxnet.numpy') def sin(x, out=None, **kwargs): r"""Trigonometric sine, element-wise. @@ -3621,6 +3773,42 @@ def tile(A, reps): return _mx_nd_np.tile(A, reps) +@set_module('mxnet.numpy') +def tril(m, k=0): + r""" + Lower triangle of an array. + + Return a copy of an array with elements above the `k`-th diagonal zeroed. + + Parameters + ---------- + m : ndarray, shape (M, N) + Input array. + k : int, optional + Diagonal above which to zero elements. `k = 0` (the default) is the + main diagonal, `k < 0` is below it and `k > 0` is above. + + Returns + ------- + tril : ndarray, shape (M, N) + Lower triangle of `m`, of same shape and data-type as `m`. + + See Also + -------- + triu : same thing, only for the upper triangle + + Examples + -------- + >>> a = np.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]]) + >>> np.tril(a, -1) + array([[ 0., 0., 0.], + [ 4., 0., 0.], + [ 7., 8., 0.], + [10., 11., 12.]]) + """ + return _mx_nd_np.tril(m, k) + + @set_module('mxnet.numpy') def arange(start, stop=None, step=1, dtype=None, ctx=None): """Return evenly spaced values within a given interval. diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 57b18ecaf547..9c055c401b31 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -37,7 +37,7 @@ 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', - 'unique'] + 'unique', 'lcm', 'tril', 'identity', 'take'] def _num_outputs(sym): @@ -347,13 +347,13 @@ def slice_like(self, *args, **kwargs): """ raise AttributeError('_Symbol object has no attribute slice_like') - def take(self, *args, **kwargs): + def take(self, indices, axis=None, mode='raise'): # pylint: disable=arguments-differ, redefined-outer-name """Convenience fluent method for :py:func:`take`. The arguments are the same as for :py:func:`take`, with this array as data. """ - raise NotImplementedError + return take(self, indices, axis, mode=mode) def one_hot(self, *args, **kwargs): """Convenience fluent method for :py:func:`one_hot`. @@ -906,7 +906,7 @@ def zeros(shape, dtype=_np.float32, order='C', ctx=None): @set_module('mxnet.symbol.numpy') def ones(shape, dtype=_np.float32, order='C', ctx=None): - """Return a new array of given shape and type, filled with zeros. + """Return a new array of given shape and type, filled with ones. This function currently only supports storing multi-dimensional data in row-major (C-style). @@ -928,7 +928,7 @@ def ones(shape, dtype=_np.float32, order='C', ctx=None): Returns ------- out : ndarray - Array of zeros with the given shape, dtype, and ctx. + Array of ones with the given shape, dtype, and ctx. """ if order != 'C': raise NotImplementedError @@ -993,6 +993,107 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin return _npi.full(shape=shape, value=fill_value, ctx=ctx, dtype=dtype, out=out) +@set_module('mxnet.symbol.numpy') +def identity(n, dtype=None, ctx=None): + """ + Return the identity array. + + The identity array is a square array with ones on + the main diagonal. + + Parameters + ---------- + n : int + Number of rows (and columns) in `n` x `n` output. + dtype : data-type, optional + Data-type of the output. Defaults to ``numpy.float32``. + ctx : Context, optional + An optional device context (default is the current default context). + + Returns + ------- + out : _Symbol + `n` x `n` array with its main diagonal set to one, + and all other elements 0. + """ + if not isinstance(n, int): + raise TypeError("Input 'n' should be an integer") + if n < 0: + raise ValueError("Input 'n' cannot be negative") + if ctx is None: + ctx = current_context() + dtype = _np.float32 if dtype is None else dtype + return _npi.identity(shape=(n, n), ctx=ctx, dtype=dtype) + + +# pylint: disable=redefined-outer-name +@set_module('mxnet.symbol.numpy') +def take(a, indices, axis=None, mode='raise', out=None): + r""" + Take elements from an array along an axis. + + When axis is not None, this function does the same thing as "fancy" + indexing (indexing arrays using arrays); however, it can be easier to use + if you need elements along a given axis. A call such as + ``np.take(arr, indices, axis=3)`` is equivalent to + ``arr[:,:,:,indices,...]``. + + Explained without fancy indexing, this is equivalent to the following use + of `ndindex`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of + indices:: + + Ni, Nk = a.shape[:axis], a.shape[axis+1:] + Nj = indices.shape + for ii in ndindex(Ni): + for jj in ndindex(Nj): + for kk in ndindex(Nk): + out[ii + jj + kk] = a[ii + (indices[jj],) + kk] + + Parameters + ---------- + a : _Symbol + The source array. + indices : _Symbol + The indices of the values to extract. Also allow scalars for indices. + axis : int, optional + The axis over which to select values. By default, the flattened + input array is used. + out : _Symbol or None, optional + Dummy parameter to keep the consistency with the ndarray counterpart. + mode : {'clip', 'wrap'}, optional + Specifies how out-of-bounds indices will behave. + + * 'clip' -- clip to the range (default) + * 'wrap' -- wrap around + + 'clip' mode means that all indices that are too large are replaced + by the index that addresses the last element along that axis. Note + that this disables indexing with negative numbers. + + Returns + ------- + out : _Symbol + The returned array has the same type as `a`. + + Notes + ----- + + This function differs from the original `numpy.take + `_ in + the following way(s): + + - Only ndarray or scalar ndarray is accepted as valid input. + """ + if mode not in ('wrap', 'clip', 'raise'): + raise NotImplementedError( + "function take does not support mode '{}'".format(mode)) + if axis: + return _npi.take(a, indices, axis, mode, out) + else: + return _npi.take(_npi.reshape(a, -1), indices, 0, mode, out) +# pylint: enable=redefined-outer-name + + #pylint: disable= too-many-arguments, no-member, protected-access def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, out=None): """ Helper function for element-wise operation. @@ -1079,6 +1180,36 @@ def power(x1, x2, out=None): return _ufunc_helper(x1, x2, _npi.power, _np.power, _npi.power_scalar, _npi.rpower_scalar, out) +@set_module('mxnet.symbol.numpy') +def lcm(x1, x2, out=None): + """ + Returns the lowest common multiple of ``|x1|`` and ``|x2|`` + + Parameters + ---------- + x1, x2 : ndarrays or scalar values + The arrays for computing lowest common multiple. If x1.shape != x2.shape, + they must be broadcastable to a common shape (which may be the shape of + one or the other). + + out : ndarray or None, optional + A location into which the result is stored. If provided, it must have a shape + that the inputs broadcast to. If not provided or None, a freshly-allocated array + is returned. + + Returns + ------- + y : ndarray or scalar + The lowest common multiple of the absolute value of the inputs + This is a scalar if both `x1` and `x2` are scalars. + + See Also + -------- + gcd : The greatest common divisor + """ + return _ufunc_helper(x1, x2, _npi.lcm, _np.lcm, _npi.lcm_scalar, None, out) + + @set_module('mxnet.symbol.numpy') def tensordot(a, b, axes=2): r""" @@ -1229,6 +1360,33 @@ def expand_dims(a, axis): return _npi.expand_dims(a, axis) +@set_module('mxnet.symbol.numpy') +def tril(m, k=0): + r""" + Lower triangle of an array. + + Return a copy of an array with elements above the `k`-th diagonal zeroed. + + Parameters + ---------- + m : _Symbol, shape (M, N) + Input array. + k : int, optional + Diagonal above which to zero elements. `k = 0` (the default) is the + main diagonal, `k < 0` is below it and `k > 0` is above. + + Returns + ------- + tril : _Symbol, shape (M, N) + Lower triangle of `m`, of same shape and data-type as `m`. + + See Also + -------- + triu : same thing, only for the upper triangle + """ + return _npi.tril(m, k) + + def _unary_func_helper(x, fn_array, fn_scalar, out=None, **kwargs): """Helper function for unary operators. diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h index 6dae2dfa20c4..dc83a4b1f87f 100644 --- a/src/operator/elemwise_op_common.h +++ b/src/operator/elemwise_op_common.h @@ -186,6 +186,25 @@ inline bool ElemwiseType(const nnvm::NodeAttrs& attrs, attrs, in_attrs, out_attrs, -1); } +// Special case of ElemwiseType. Constrains dtype to integer types +template +inline bool ElemwiseIntType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK(in_attrs->at(0) == mshadow::kInt64 || + in_attrs->at(0) == mshadow::kInt32 || + in_attrs->at(0) == mshadow::kInt8 || + in_attrs->at(0) == mshadow::kUint8) << "Only supports integer types."; + if (n_in != -1) { + CHECK_EQ(in_attrs->size(), static_cast(n_in)) << " in operator " << attrs.name; + } + if (n_out != -1) { + CHECK_EQ(out_attrs->size(), static_cast(n_out)) << " in operator " << attrs.name; + } + return ElemwiseAttr( + attrs, in_attrs, out_attrs, -1); +} + // Transfer gradient and input to FGradient function struct ElemwiseGradUseIn { const char *op_name; diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 6261638c03ec..92655c146193 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -55,6 +55,7 @@ using std::isnan; #endif using std::enable_if; using std::is_unsigned; +using std::is_integral; #define MXNET_UNARY_MATH_OP(name, expr) \ struct name : public mxnet_op::tunable { \ @@ -1088,6 +1089,48 @@ struct nanprod_grad : public mxnet_op::tunable { } }; +/*! \brief used for computing binary lowest common multiple */ +struct lcm : public mxnet_op::tunable { + template + MSHADOW_XINLINE static typename enable_if::value, DType>::type + Map(DType a, DType b) { + // minus cases. + if (a < 0) { + a = -a; + } + if (b < 0) { + b = -b; + } + // handle zero-valued cases. + DType c; + if (a == 0 || b == 0) { + c = 0; + } else { + DType tmp; + DType tmp_a = a; + DType tmp_b = b; + if (a < b) { + tmp = a; + a = b; + b = tmp; + } + while (a % b != 0) { + a = a % b; + tmp = a; + a = b; + b = tmp; + } + c = tmp_a / b * tmp_b; + } + return c; + } + template + MSHADOW_XINLINE static typename enable_if::value, DType>::type + Map(DType a, DType b) { + return DType(0.0f); + } +}; + } // namespace mshadow_op } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index 16d4ef88f5c5..a786d1db5892 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -96,6 +96,23 @@ NNVM_REGISTER_OP(_backward_npi_copysign) .set_attr("FCompute", BinaryBroadcastBackwardUseIn); +NNVM_REGISTER_OP(_npi_lcm) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", +[](const NodeAttrs& attrs) { + return std::vector{"lhs", "rhs"}; +}) +.set_attr("FInferShape", BinaryBroadcastShape) +.set_attr("FInferType", ElemwiseIntType<2, 1>) +.set_attr("FInplaceOption", +[](const NodeAttrs& attrs){ + return std::vector >{{0, 0}, {1, 0}}; +}) +.set_attr("FCompute", BinaryBroadcastCompute) +.add_argument("lhs", "NDArray-or-Symbol", "First input to the function") +.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function"); + MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_add_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseNone{"_copy"}); @@ -263,5 +280,21 @@ NNVM_REGISTER_OP(_backward_npi_hypot) .set_attr("FCompute", BinaryBroadcastBackwardUseIn); +NNVM_REGISTER_OP(_npi_lcm_scalar) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser([](NodeAttrs* attrs) { + attrs->parsed = std::stod(attrs->dict["scalar"]); + }) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseIntType<1, 1>) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; + }) +.add_argument("data", "NDArray-or-Symbol", "source input") +.add_argument("scalar", "int", "scalar input") +.set_attr("FCompute", BinaryScalarOp::Compute); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu index 77525ce7acea..806debf431b5 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op.cu @@ -45,6 +45,9 @@ NNVM_REGISTER_OP(_npi_power) NNVM_REGISTER_OP(_npi_copysign) .set_attr("FCompute", BinaryBroadcastCompute); +NNVM_REGISTER_OP(_npi_lcm) +.set_attr("FCompute", BinaryBroadcastCompute); + NNVM_REGISTER_OP(_backward_npi_copysign) .set_attr("FCompute", BinaryBroadcastBackwardUseIn); @@ -112,5 +115,8 @@ NNVM_REGISTER_OP(_npi_rarctan2_scalar) NNVM_REGISTER_OP(_backward_npi_rarctan2_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); +NNVM_REGISTER_OP(_npi_lcm_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_init_op.cc b/src/operator/numpy/np_init_op.cc index 4f031bdaa050..2477573c2413 100644 --- a/src/operator/numpy/np_init_op.cc +++ b/src/operator/numpy/np_init_op.cc @@ -71,6 +71,16 @@ NNVM_REGISTER_OP(_npi_ones) .set_attr("FCompute", FillCompute) .add_arguments(InitOpParam::__FIELDS__()); +NNVM_REGISTER_OP(_npi_identity) +.describe("Return a new identity array of given shape, type, and context.") +.set_num_inputs(0) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", InitShape) +.set_attr("FInferType", InitType) +.set_attr("FCompute", IdentityCompute) +.add_arguments(InitOpParam::__FIELDS__()); + NNVM_REGISTER_OP(_np_zeros_like) .set_num_inputs(1) .set_num_outputs(1) diff --git a/src/operator/numpy/np_init_op.cu b/src/operator/numpy/np_init_op.cu index 49f1051735d8..e68dd9ad36a1 100644 --- a/src/operator/numpy/np_init_op.cu +++ b/src/operator/numpy/np_init_op.cu @@ -35,6 +35,9 @@ NNVM_REGISTER_OP(_npi_zeros) NNVM_REGISTER_OP(_npi_ones) .set_attr("FCompute", FillCompute); +NNVM_REGISTER_OP(_npi_identity) +.set_attr("FCompute", IdentityCompute); + NNVM_REGISTER_OP(_np_zeros_like) .set_attr("FCompute", FillCompute); diff --git a/src/operator/numpy/np_init_op.h b/src/operator/numpy/np_init_op.h index 5c41820b57f8..3e1c345d59c3 100644 --- a/src/operator/numpy/np_init_op.h +++ b/src/operator/numpy/np_init_op.h @@ -20,8 +20,9 @@ /*! * Copyright (c) 2019 by Contributors * \file np_init_op.h - * \brief CPU Implementation of numpy init op + * \brief Function definition of numpy init op */ + #ifndef MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_ #define MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_ @@ -65,6 +66,22 @@ struct indices_fwd { } }; +template +struct identity { + template + MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const int n) { + using namespace mxnet_op; + + const index_t row_id = i / n; + const index_t col_id = i % n; + if (row_id == col_id) { + KERNEL_ASSIGN(out_data[i], req, static_cast(1)); + } else { + KERNEL_ASSIGN(out_data[i], req, static_cast(0)); + } + } +}; + template void IndicesCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -101,6 +118,28 @@ void IndicesCompute(const nnvm::NodeAttrs& attrs, } } +template +void IdentityCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet_op; + using namespace mshadow; + CHECK_EQ(inputs.size(), 0U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + Stream *s = ctx.get_stream(); + const TBlob& out_data = outputs[0]; + int n = out_data.shape_[0]; + MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, xpu>::Launch( + s, out_data.Size(), out_data.dptr(), n); + }); + }); +} + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_tril_op-inl.h b/src/operator/numpy/np_tril_op-inl.h new file mode 100644 index 000000000000..1ad74e887b6c --- /dev/null +++ b/src/operator/numpy/np_tril_op-inl.h @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_tril_op-inl.h + * \brief Function definition of the tril (lower triangle of an array) op + */ + +#ifndef MXNET_OPERATOR_NUMPY_NP_TRIL_OP_INL_H_ +#define MXNET_OPERATOR_NUMPY_NP_TRIL_OP_INL_H_ + +#include +#include +#include +#include "../mxnet_op.h" +#include "../operator_common.h" +#include "../elemwise_op_common.h" + +namespace mxnet { +namespace op { + +struct TrilParam : public dmlc::Parameter { + int k; + DMLC_DECLARE_PARAMETER(TrilParam) { + DMLC_DECLARE_FIELD(k) + .set_default(0) + .describe("Diagonal in question. The default is 0. " + "Use k>0 for diagonals above the main diagonal, " + "and k<0 for diagonals below the main diagonal. " + "If input has shape (S0 S1) k must be between -S0 and S1"); + } +}; + +inline bool TrilOpShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + + const mxnet::TShape& ishape = (*in_attrs)[0]; + mxnet::TShape oshape; + + if (!mxnet::ndim_is_known(ishape)) { + return false; + } + + if (ishape.ndim() == 1) { + auto s = ishape[0]; + oshape = mxnet::TShape({s, s}); + } else { + oshape = ishape; + } + + if (shape_is_none(oshape)) { + LOG(FATAL) << "Diagonal does not exist."; + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); + + return shape_is_known(out_attrs->at(0)); +} + +template +struct tril1Dforward { + template + MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* data, + mshadow::Shape<2> oshape, int k) { + using namespace mxnet_op; + + const index_t row_id = i / oshape[1]; + const index_t col_id = i % oshape[1]; + if (col_id > (row_id + k)) { + KERNEL_ASSIGN(out[i], req, static_cast(0)); + } else { + KERNEL_ASSIGN(out[i], req, data[col_id]); + } + } +}; + +template +struct tril1Dbackward { + template + MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* data, + mshadow::Shape<1> oshape, int k) { + using namespace mxnet_op; + auto m = oshape[0]; + auto start = (i > k) ? (i - k) : 0; + DType res = 0; + for (auto y = start; y < m; y++) { + res += data[y * m + i]; + } + KERNEL_ASSIGN(out[i], req, res); + } +}; + +template +struct tril2D { + template + MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* data, + mshadow::Shape<2> oshape, int k) { + using namespace mxnet_op; + + const index_t row_id = i / oshape[1]; + const index_t col_id = i % oshape[1]; + if (col_id > (row_id + k)) { + KERNEL_ASSIGN(out[i], req, static_cast(0)); + } else { + KERNEL_ASSIGN(out[i], req, data[i]); + } + } +}; + +template +struct tril3D { + template + MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* data, + mshadow::Shape<3> oshape, int k) { + using namespace mxnet_op; + + const index_t row_id = i % (oshape[1] * oshape[2]) / oshape[2]; + const index_t col_id = i % (oshape[1] * oshape[2]) % oshape[2]; + if (col_id > (row_id + k)) { + KERNEL_ASSIGN(out[i], req, static_cast(0)); + } else { + KERNEL_ASSIGN(out[i], req, data[i]); + } + } +}; + +template +void TrilOpProcess(const TBlob& in_data, + const TBlob& out_data, + index_t dsize, + const TrilParam& param, + mxnet_op::Stream *s, + const std::vector& req) { + using namespace mxnet_op; + using namespace mshadow; + + const mxnet::TShape& ishape = in_data.shape_; + const mxnet::TShape& oshape = out_data.shape_; + + if (ishape.ndim() == 2 && oshape.ndim() == 2) { + MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, xpu>::Launch( + s, dsize, out_data.dptr(), in_data.dptr(), + Shape2(oshape[0], oshape[1]), param.k); + }); + }); + } else if (ishape.ndim() > 2) { + MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, xpu>::Launch( + s, dsize, out_data.dptr(), in_data.dptr(), + oshape.FlatTo3D(oshape.ndim() - 2), param.k); + }); + }); + } else { + MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + if (back) { + Kernel, xpu>::Launch( + s, dsize, out_data.dptr(), in_data.dptr(), + Shape1(oshape[0]), param.k); + } else { + Kernel, xpu>::Launch( + s, dsize, out_data.dptr(), in_data.dptr(), + Shape2(oshape[0], oshape[1]), param.k); + } + }); + }); + } +} + +template +void TrilOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet_op; + using namespace mshadow; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + Stream *s = ctx.get_stream(); + const TBlob& in_data = inputs[0]; + const TBlob& out_data = outputs[0]; + const TrilParam& param = nnvm::get(attrs.parsed); + + TrilOpProcess(in_data, out_data, out_data.Size(), param, s, req); +} + +template +void TrilOpBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet_op; + using namespace mshadow; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + Stream *s = ctx.get_stream(); + + const TBlob& in_data = inputs[0]; + const TBlob& out_data = outputs[0]; + const TrilParam& param = nnvm::get(attrs.parsed); + + TrilOpProcess(in_data, out_data, out_data.Size(), param, s, req); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_TRIL_OP_INL_H_ diff --git a/src/operator/numpy/np_tril_op.cc b/src/operator/numpy/np_tril_op.cc new file mode 100644 index 000000000000..5c4b339b7768 --- /dev/null +++ b/src/operator/numpy/np_tril_op.cc @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! +* Copyright (c) 2019 by Contributors +* \file np_tril_op.cc +* \brief CPU implementation of numpy tril operator +*/ + +#include "./np_tril_op-inl.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(TrilParam); + +NNVM_REGISTER_OP(_npi_tril) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) +.set_attr("FInferShape", TrilOpShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FCompute", TrilOpForward) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; + }) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_tril"}) +.add_argument("data", "NDArray-or-Symbol", "Input ndarray") +.add_arguments(TrilParam::__FIELDS__()); + + +NNVM_REGISTER_OP(_backward_tril) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("TIsBackward", true) +.set_attr("FCompute", TrilOpBackward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_tril_op.cu b/src/operator/numpy/np_tril_op.cu new file mode 100644 index 000000000000..64613b505ded --- /dev/null +++ b/src/operator/numpy/np_tril_op.cu @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_tril_op.cu + * \brief GPU implementation of numpy tril operator + */ + +#include "./np_tril_op-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_tril) +.set_attr("FCompute", TrilOpForward); + +NNVM_REGISTER_OP(_backward_tril) +.set_attr("FCompute", TrilOpBackward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index 1d644386cdbb..b81cd78ad507 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -369,6 +369,7 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::lcm); // NOLINT() IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>); // NOLINT() IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<1>); // NOLINT() IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel); // NOLINT() diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index 463e9f98820e..9961218b5482 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -288,6 +288,10 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs, const mxnet::TShape& arrshape = inputs[take_::kArr].shape_; const mxnet::TShape& oshape = outputs[take_::kOut].shape_; + if (idxshape.Size() == 0) { + return; + } + Stream *s = ctx.get_stream(); const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0); diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu index 9a46d894ee22..77d85d8e1e10 100644 --- a/src/operator/tensor/indexing_op.cu +++ b/src/operator/tensor/indexing_op.cu @@ -479,6 +479,10 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs, const mxnet::TShape& arrshape = inputs[take_::kArr].shape_; const mxnet::TShape& oshape = outputs[take_::kOut].shape_; + if (idxshape.Size() == 0) { + return; + } + Stream *s = ctx.get_stream(); const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0); diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 161acae0ebf2..16520ddbb242 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -670,9 +670,9 @@ struct TakeParam: public dmlc::Parameter { .set_default(take_::kClip) .describe("Specify how out-of-bound indices bahave. Default is \"clip\"." " \"clip\" means clip to the range. So, if all indices mentioned are too large," - " they are replaced by the index that addresses the last element along an axis. " - " \"wrap\" means to wrap around. " - " \"raise\" means to raise an error, not supported yet."); + " they are replaced by the index that addresses the last element along an axis." + " \"wrap\" means to wrap around." + " \"raise\" means to raise an error when index out of range."); } }; @@ -1030,6 +1030,10 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs, const mxnet::TShape& arrshape = outputs[0].shape_; const mxnet::TShape& oshape = inputs[0].shape_; + if (idxshape.Size() == 0) { + return; + } + if (req[take_::kIdx] != kNullOp) { mxnet_op::Kernel::Launch( s, idxshape.Size(), outputs[take_::kIdx].dptr()); diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 923ee53fc400..3a5e72b53d58 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -181,6 +181,51 @@ def check_ones_array_creation(shape, dtype): assert type(y[1]) == np.ndarray +@with_seed() +@use_np +def test_identity(): + class TestIdentity(HybridBlock): + def __init__(self, shape, dtype=None): + super(TestIdentity, self).__init__() + self._n = n + self._dtype = dtype + + def hybrid_forward(self, F, x): + return x * F.np.identity(self._n, self._dtype) + + class TestIdentityOutputType(HybridBlock): + def hybrid_forward(self, F, x): + return x, F.np.identity(0) + + def check_identity_array_creation(shape, dtype): + np_out = _np.identity(n=n, dtype=dtype) + mx_out = np.identity(n=n, dtype=dtype) + assert same(mx_out.asnumpy(), np_out) + if dtype is None: + assert mx_out.dtype == _np.float32 + assert np_out.dtype == _np.float64 + + ns = [0, 1, 2, 3, 5, 15, 30, 200] + dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, None] + for n in ns: + for dtype in dtypes: + check_identity_array_creation(n, dtype) + x = mx.nd.array(_np.random.uniform(size=(n, n)), dtype=dtype).as_np_ndarray() + if dtype is None: + x = x.astype('float32') + for hybridize in [True, False]: + test_identity = TestIdentity(n, dtype) + test_identity_output_type = TestIdentityOutputType() + if hybridize: + test_identity.hybridize() + test_identity_output_type.hybridize() + y = test_identity(x) + assert type(y) == np.ndarray + assert same(x.asnumpy() * _np.identity(n, dtype), y.asnumpy()) + y = test_identity_output_type(x) + assert type(y[1]) == np.ndarray + + @with_seed() def test_np_ndarray_binary_element_wise_ops(): np_op_map = { diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index d2dc6ab269dc..7e3d9655f771 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -953,6 +953,67 @@ def hybrid_forward(self, F, x): assert same(ret_mx.asnumpy(), ret_np) +@with_seed() +@use_np +def test_np_tril(): + # numpy tril does not support scalar array (zero-dim) + config = [ + ((4, 2), 3), + ((4, 2), 9), + ((4, 2), 0), + ((4, 2), -1), + ((4, 5, 6), 0), + ((4, 5, 6), 5), + ((4, 5, 6), 2), + ((4, 5, 6), -2), + ((4, 5, 6), -5), + ((4, 0), 0), + ((4, 0), 2), + ((4, 0), 4), + ((4, 0), -3), + ((4, 0, 5), 0), + ((4, 0, 5), 1), + ((4, 0, 5), 5), + ((4, 0, 5), -3), + ((3, ), 0), + ((3, ), 2), + ((3, ), 5) + ] + + class TestTril(HybridBlock): + def __init__(self, k): + super(TestTril, self).__init__() + self._k = k + + def hybrid_forward(self, F, x): + return F.np.tril(x, k=self._k) + + for prefix in [1, -1]: + for shape, k in config: + data_np = _np.random.uniform(size=shape) + data_mx = np.array(data_np, dtype=data_np.dtype) + data_mx.attach_grad() + ret_np = _np.tril(data_np, k*prefix) + with mx.autograd.record(): + ret_mx = np.tril(data_mx, k*prefix) + assert same(ret_mx.asnumpy(), ret_np) + ret_mx.backward() + if len(shape) == 2: + grad_np = _np.tri(*shape, k=k*prefix) + assert same(data_mx.grad.asnumpy(), grad_np) + if len(shape) == 1: + grad_np = _np.tri(*shape, k=k*prefix) + grad_np = grad_np.sum(axis=0, keepdims=False) + assert same(data_mx.grad.asnumpy(), grad_np) + + net = TestTril(k*prefix) + for hybrid in [False, True]: + if hybrid: + net.hybridize() + ret_mx = net(data_mx) + assert same(ret_mx.asnumpy(), ret_np) + + @with_seed() @use_np def test_np_unary_funcs(): @@ -2600,6 +2661,155 @@ def hybrid_forward(self, F, a): assert_almost_equal(mx_out[i].asnumpy(), np_out[i], rtol=1e-3, atol=1e-5) +@with_seed() +@use_np +def test_np_lcm(): + shapes = [ + ((3, 1), (3,)), + ((3, 1), (3, 5)), + ((1, 4), (3, 1)), + ((), ()), + ((4, 0), ()), + ((3, 4, 5), ()), + ((), (3, 4, 5)), + ((3, 4, 5), (3, 1, 5)), + ((5, 1), (5, 2)) + ] + + class TestLcm(HybridBlock): + def __init__(self): + super(TestLcm, self).__init__() + + def hybrid_forward(self, F, x1, x2): + return F.np.lcm(x1, x2) + + for hybridize in [False]: + for shape in shapes: + test_lcm = TestLcm() + if hybridize: + test_lcm.hybridize() + + x1 = rand_ndarray(shape[0]).astype(_np.int32).as_np_ndarray() + x2 = rand_ndarray(shape[1]).astype(_np.int32).as_np_ndarray() + + np_out = _np.lcm(x1.asnumpy(), x2.asnumpy()) + mx_out = test_lcm(x1, x2) + + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + # Test imperative once again + mx_out = np.lcm(x1, x2) + np_out = _np.lcm(x1.asnumpy(), x2.asnumpy()) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + +@with_seed() +@use_np +def test_np_take(): + configs = [ + ((4, 4), (4, 0), None), + ((4, 4), (4, 0), 0), + ((4, 4), (4, 0), 1), + ((), (4, 0), None), + ((), (5, ), None), + ((), (4, 5), None), + ((), (), None), + ((3, 4), (), None), + ((3, 4), (), 0), + ((3, 4), (), 1), + ((3, 4, 5), (), 2), + ((3, 4, 5), (), -3), + ] + + class TestTake(HybridBlock): + def __init__(self, axis, mode): + super(TestTake, self).__init__() + self._axis = axis + self._mode = mode + + def hybrid_forward(self, F, a, indices): + return F.np.take(a, indices, axis=self._axis, mode=self._mode) + + def grad_helper(grad_in, axis, idx, mode): + k = grad_in.shape[axis] + if mode == 'clip': + idx = 0 if idx < 0 else idx + idx = k - 1 if idx >= k else idx + else: + idx = idx % k + if axis == None: + grad_in[idx] += 1.0 + elif axis == 0: + if axis == len(grad_in.shape) - 1: + grad_in[idx] += 1.0 + else: + grad_in[idx, :] += 1.0 + elif axis == 1: + if axis == len(grad_in.shape) - 1: + grad_in[:, idx] += 1.0 + else: + grad_in[:, idx, :] += 1.0 + elif axis == 2: + if axis == len(grad_in.shape) - 1: + grad_in[:, :, idx] += 1.0 + else: + grad_in[:, :, idx, :] += 1.0 + elif axis == 3: + if axis == len(grad_in.shape) - 1: + grad_in[:, :, :, idx] += 1.0 + else: + grad_in[:, :, :, idx, :] += 1.0 + elif axis == 4: + grad_in[:, :, :, :, idx] += 1.0 + else: + raise ValueError("axis %d is not supported..." % axis) + + def check_output_n_grad(data_shape, idx_shape, axis, mode): + data_real = _np.random.normal(size=data_shape).astype('float32') + idx_real = _np.random.randint(low=-100, high=100, size=idx_shape) + same(np.take(np.array(data_real), np.array(idx_real), axis=axis, mode=mode).asnumpy(), + _np.take(data_real, idx_real, axis=axis, mode=mode)) + + grad_in = _np.zeros(data_shape, dtype='float32') + + test_take = TestTake(axis=axis, mode=mode) + if hybridize: + test_take.hybridize() + x = np.array(data_real) + x.attach_grad() + with mx.autograd.record(): + mx_out = test_take(x, np.array(idx_real)) + same(mx_out.asnumpy(), _np.take(data_real, idx_real, axis=axis, mode=mode)) + + if axis and axis < 0: + axis += len(data_shape) + try: + for i in _np.nditer(idx_real): + grad_helper(grad_in, axis, i, mode) + except: + pass + + mx_out.backward() + same(x.grad.asnumpy(), grad_in) + + for hybridize in [True, False]: + for mode in ['clip', 'wrap']: + for data_ndim in range(1, 5): + for idx_ndim in range(1, 4): + for axis in range(-data_ndim, data_ndim): + data_shape = () + for _ in range(data_ndim): + data_shape += (_np.random.randint(low=1, high=5), ) + idx_shape = () + for _ in range(idx_ndim): + idx_shape += (_np.random.randint(low=1, high=5), ) + check_output_n_grad(data_shape, idx_shape, axis, mode) + + for config in configs: + check_output_n_grad(config[0], config[1], config[2], mode) + + if __name__ == '__main__': import nose nose.runmodule()