From e9e267ef711261f20528d443f38eb7b9e991057c Mon Sep 17 00:00:00 2001 From: reminisce Date: Sat, 14 Sep 2019 09:33:08 -0700 Subject: [PATCH] Fix remaining errors reported by D2L (#16157) * Fix * Fix ssd * Add test for np.empty * Add np.linalg.norm * Fix indexing bug * Improve doc --- python/mxnet/_numpy_op_doc.py | 353 +++++++++++++++++++- python/mxnet/gluon/loss.py | 10 +- python/mxnet/ndarray/numpy/linalg.py | 48 ++- python/mxnet/numpy/linalg.py | 42 ++- python/mxnet/numpy/multiarray.py | 102 ++++-- python/mxnet/symbol/numpy/_symbol.py | 20 +- python/mxnet/symbol/numpy/linalg.py | 48 ++- tests/python/unittest/test_numpy_gluon.py | 38 ++- tests/python/unittest/test_numpy_ndarray.py | 71 +++- tests/python/unittest/test_numpy_op.py | 81 ++++- 10 files changed, 751 insertions(+), 62 deletions(-) diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py index 4e5ea74c2771..c95752a00cb6 100644 --- a/python/mxnet/_numpy_op_doc.py +++ b/python/mxnet/_numpy_op_doc.py @@ -21,7 +21,10 @@ def _np_ones_like(a): - """Return an array of ones with the same shape and type as a given array. + """ + ones_like(a) + + Return an array of ones with the same shape and type as a given array. Parameters ---------- @@ -38,7 +41,10 @@ def _np_ones_like(a): def _np_zeros_like(a): - """Return an array of zeros with the same shape and type as a given array. + """ + zeros_like(a) + + Return an array of zeros with the same shape and type as a given array. Parameters ---------- @@ -55,7 +61,10 @@ def _np_zeros_like(a): def _np_cumsum(a, axis=None, dtype=None, out=None): - """Return the cumulative sum of the elements along a given axis. + """ + cumsum(a, axis=None, dtype=None, out=None) + + Return the cumulative sum of the elements along a given axis. Parameters ---------- @@ -103,3 +112,341 @@ def _np_cumsum(a, axis=None, dtype=None, out=None): """ pass + + +def _np_repeat(a, repeats, axis=None): + """ + repeat(a, repeats, axis=None) + + Repeat elements of an array. + + Parameters + ---------- + a : ndarray + Input array. + repeats : int + The number of repetitions for each element. + axis : int, optional + The axis along which to repeat values. By default, use the + flattened input array, and return a flat output array. + + Returns + ------- + repeated_array : ndarray + Output array which has the same shape as `a`, except along + the given axis. + + Notes + ----- + Unlike the official NumPy ``repeat`` operator, this operator currently + does not support array of ints for the parameter `repeats`. + + Examples + -------- + >>> x = np.arange(4).reshape(2, 2) + >>> x + array([[0., 1.], + [2., 3.]]) + >>> np.repeat(x, repeats=3) + array([0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3.]) + >>> np.repeat(x, repeats=3, axis=0) + array([[0., 1.], + [0., 1.], + [0., 1.], + [2., 3.], + [2., 3.], + [2., 3.]]) + >>> np.repeat(x, repeats=3, axis=1) + array([[0., 0., 0., 1., 1., 1.], + [2., 2., 2., 3., 3., 3.]]) + """ + pass + + +def _np_transpose(a, axes=None): + """ + transpose(a, axes=None) + + Permute the dimensions of an array. + + Parameters + ---------- + a : ndarray + Input array. + axes : list of ints, optional + By default, reverse the dimensions, + otherwise permute the axes according to the values given. + + Returns + ------- + p : ndarray + a with its axes permuted. + + Notes + ----- + This function differs from the original `numpy.transpose + `_ in + the following way(s): + + - only ndarray is accepted as valid input, python iterables are not supported + - the operator always returns an `ndarray` that does not share the memory with the input + + Examples + -------- + >>> x = np.arange(4).reshape((2,2)) + >>> x + array([[0., 1.], + [2., 3.]]) + >>> np.transpose(x) + array([[0., 2.], + [1., 3.]]) + >>> x = np.ones((1, 2, 3)) + >>> np.transpose(x, (1, 0, 2)).shape + (2, 1, 3) + """ + pass + + +def _np_dot(a, b, out=None): + """dot(a, b, out=None) + + Dot product of two arrays. Specifically, + + - If both `a` and `b` are 1-D arrays, it is inner product of vectors + + - If both `a` and `b` are 2-D arrays, it is matrix multiplication, + + - If either `a` or `b` is 0-D (scalar), it is equivalent to :func:`multiply` + and using ``np.multiply(a, b)`` or ``a * b`` is preferred. + + - If `a` is an N-D array and `b` is a 1-D array, it is a sum product over + the last axis of `a` and `b`. + + - If `a` is an N-D array and `b` is a 2-D array, it is a + sum product over the last axis of `a` and the second-to-last axis of `b`:: + + dot(a, b)[i,j,k] = sum(a[i,j,:] * b[:,k]) + + Parameters + ---------- + a : ndarray + First argument. + b : ndarray + Second argument. + + out : ndarray, optional + Output argument. It must have the same shape and type as the expected output. + + Returns + ------- + output : ndarray + Returns the dot product of `a` and `b`. If `a` and `b` are both + scalars or both 1-D arrays then a scalar is returned; otherwise + an array is returned. + If `out` is given, then it is returned + + Examples + -------- + >>> a = np.array(3) + >>> b = np.array(4) + >>> np.dot(a, b) + array(12.) + + For 2-D arrays it is the matrix product: + + >>> a = np.array([[1, 0], [0, 1]]) + >>> b = np.array([[4, 1], [2, 2]]) + >>> np.dot(a, b) + array([[4., 1.], + [2., 2.]]) + + >>> a = np.arange(3*4*5*6).reshape((3,4,5,6)) + >>> b = np.arange(5*6)[::-1].reshape((6,5)) + >>> np.dot(a, b)[2,3,2,2] + array(29884.) + >>> np.sum(a[2,3,2,:] * b[:,2]) + array(29884.) + """ + pass + + +def _np_sum(a, axis=0, dtype=None, keepdims=None, initial=None, out=None): + r""" + sum(a, axis=None, dtype=None, keepdims=_Null, initial=_Null, out=None) + + Sum of array elements over a given axis. + + Parameters + ---------- + a : ndarray + Input data. + axis : None or int, optional + Axis or axes along which a sum is performed. The default, + axis=None, will sum all of the elements of the input array. If + axis is negative it counts from the last to the first axis. + dtype : dtype, optional + The type of the returned array and of the accumulator in which the + elements are summed. The default type is float32. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the input array. + + If the default value is passed, then `keepdims` will not be + passed through to the `sum` method of sub-classes of + `ndarray`, however any non-default value will be. If the + sub-classes `sum` method does not implement `keepdims` any + exceptions will be raised. + initial: Currently only supports None as input, optional + Starting value for the sum. + Currently not implemented. Please use ``None`` as input or skip this argument. + out : ndarray or None, optional + Alternative output array in which to place the result. It must have + the same shape and dtype as the expected output. + + Returns + ------- + sum_along_axis : ndarray + An ndarray with the same shape as `a`, with the specified + axis removed. If an output array is specified, a reference to + `out` is returned. + + Notes + ----- + - Input type does not support Python native iterables. + - "out" param: cannot perform auto type change. out ndarray's dtype must be the same as the expected output. + - "initial" param is not supported yet. Please use None as input. + - Arithmetic is modular when using integer types, and no error is raised on overflow. + - The sum of an empty array is the neutral element 0: + + >>> a = np.empty(1) + >>> np.sum(a) + array(0.) + + This function differs from the original `numpy.sum + `_ in + the following aspects: + + - Input type does not support Python native iterables(list, tuple, ...). + - "out" param: cannot perform auto type cast. out ndarray's dtype must be the same as the expected output. + - "initial" param is not supported yet. Please use ``None`` as input or skip it. + + Examples + -------- + >>> a = np.array([0.5, 1.5]) + >>> np.sum(a) + array(2.) + >>> a = np.array([0.5, 0.7, 0.2, 1.5]) + >>> np.sum(a, dtype=np.int32) + array(2, dtype=int32) + >>> a = np.array([[0, 1], [0, 5]]) + >>> np.sum(a) + array(6.) + >>> np.sum(a, axis=0) + array([0., 6.]) + >>> np.sum(a, axis=1) + array([1., 5.]) + + With output ndarray: + + >>> a = np.array([[0, 1], [0, 5]]) + >>> b = np.ones((2,), dtype=np.float32) + >>> np.sum(a, axis = 0, out=b) + array([0., 6.]) + >>> b + array([0., 6.]) + + If the accumulator is too small, overflow occurs: + + >>> np.ones(128, dtype=np.int8).sum(dtype=np.int8) + array(-128, dtype=int8) + """ + pass + + +def _np_copy(a, out=None): + """ + copy(a, out=None) + + Return an array copy of the given object. + + Parameters + ---------- + a : ndarray + Input data. + out : ndarray or None, optional + Alternative output array in which to place the result. It must have + the same shape and dtype as the expected output. + + Returns + ------- + arr : ndarray + Array interpretation of `a`. + + Notes + ------- + This function differs from the original `numpy.copy + `_ in + the following aspects: + + - Input type does not support Python native iterables(list, tuple, ...). + - ``out`` param: cannot perform auto broadcasting. ``out`` ndarray's shape must be the same as the expected output. + - ``out`` param: cannot perform auto type cast. ``out`` ndarray's dtype must be the same as the expected output. + - Does not support "order" parameter. + + Examples + -------- + Create an array x, with a reference y and a copy z: + + >>> x = np.array([1, 2, 3]) + >>> y = x + >>> z = np.copy(x) + + Note that, when ``x`` is modified, ``y`` is also modified, but not ``z``: + + >>> x[0] = 10 + >>> x[0] == y[0] + array([1.]) + >>> x[0] == z[0] + array([0.]) + """ + pass + + +def _np_reshape(a, newshape, order='C', out=None): + """ + reshape(a, newshape, order='C') + + Gives a new shape to an array without changing its data. + This function always returns a copy of the input array if + ``out`` is not provided. + + Parameters + ---------- + a : ndarray + Array to be reshaped. + newshape : int or tuple of ints + The new shape should be compatible with the original shape. If + an integer, then the result will be a 1-D array of that length. + One shape dimension can be -1. In this case, the value is + inferred from the length of the array and remaining dimensions. + order : {'C'}, optional + Read the elements of `a` using this index order, and place the + elements into the reshaped array using this index order. 'C' + means to read / write the elements using C-like index order, + with the last axis index changing fastest, back to the first + axis index changing slowest. Other order types such as 'F'/'A' + may be added in the future. + + Returns + ------- + reshaped_array : ndarray + It will be always a copy of the original array. This behavior is different + from the official NumPy ``reshape`` operator where views of the original array may be + generated. + + See Also + -------- + ndarray.reshape : Equivalent method. + """ + pass diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py index 4cdb361eb146..45c3dee08139 100644 --- a/python/mxnet/gluon/loss.py +++ b/python/mxnet/gluon/loss.py @@ -189,9 +189,15 @@ def __init__(self, weight=None, batch_axis=0, **kwargs): def hybrid_forward(self, F, pred, label, sample_weight=None): label = _reshape_like(F, label, pred) - loss = F.abs(label - pred) + loss = F.np.abs(label - pred) if is_np_array() else F.abs(label - pred) loss = _apply_weighting(F, loss, self._weight, sample_weight) - return F.mean(loss, axis=self._batch_axis, exclude=True) + if is_np_array(): + if F is ndarray: + return F.np.mean(loss, axis=tuple(range(1, loss.ndim))) + else: + return F.npx.batch_flatten(loss).mean(axis=1) + else: + return F.mean(loss, axis=self._batch_axis, exclude=True) class SigmoidBinaryCrossEntropyLoss(Loss): diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py index 0222bb45d148..36f3f21a7588 100644 --- a/python/mxnet/ndarray/numpy/linalg.py +++ b/python/mxnet/ndarray/numpy/linalg.py @@ -18,5 +18,51 @@ """Namespace for operators used in Gluon dispatched by F=ndarray.""" from __future__ import absolute_import +from . import _op as _mx_nd_np -__all__ = [] +__all__ = ['norm'] + + +def norm(x, ord=None, axis=None, keepdims=False): + r"""Matrix or vector norm. + + This function can only support Frobenius norm for now. + The Frobenius norm is given by [1]_: + + :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + + Parameters + ---------- + x : ndarray + Input array. + ord : {'fro'}, optional + Order of the norm. + axis : {int, 2-tuple of ints, None}, optional + If `axis` is an integer, it specifies the axis of `x` along which to + compute the vector norms. If `axis` is a 2-tuple, it specifies the + axes that hold 2-D matrices, and the matrix norms of these matrices + are computed. If `axis` is None, the norm of the whole ndarray is + returned. + + keepdims : bool, optional + If this is set to True, the axes which are normed over are left in the + result as dimensions with size one. With this option the result will + broadcast correctly against the original `x`. + + Returns + ------- + n : float or ndarray + Norm of the matrix or vector(s). + + References + ---------- + .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, + Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 + """ + if ord is not None and ord != 'fro': + raise ValueError('only support Frobenius norm for now, received ord={}'.format(str(ord))) + if isinstance(axis, tuple) and len(axis) > 2: + raise ValueError('Improper number of dimensions to norm') + if ord == 'fro' and x.ndim > 2 and axis is None: + raise ValueError('Improper number of dimensions to norm') + return _mx_nd_np.sqrt(_mx_nd_np.sum(x * x, axis=axis, keepdims=keepdims)) diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py index c4109378e146..9758af47233d 100644 --- a/python/mxnet/numpy/linalg.py +++ b/python/mxnet/numpy/linalg.py @@ -18,5 +18,45 @@ """Namespace for ops used in imperative programming.""" from __future__ import absolute_import +from ..ndarray import numpy as _mx_nd_np -__all__ = [] +__all__ = ['norm'] + + +def norm(x, ord=None, axis=None, keepdims=False): + r"""Matrix or vector norm. + + This function can only support Frobenius norm for now. + The Frobenius norm is given by [1]_: + + :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + + Parameters + ---------- + x : ndarray + Input array. + ord : {'fro'}, optional + Order of the norm. + axis : {int, 2-tuple of ints, None}, optional + If `axis` is an integer, it specifies the axis of `x` along which to + compute the vector norms. If `axis` is a 2-tuple, it specifies the + axes that hold 2-D matrices, and the matrix norms of these matrices + are computed. If `axis` is None, the norm of the whole ndarray is + returned. + + keepdims : bool, optional + If this is set to True, the axes which are normed over are left in the + result as dimensions with size one. With this option the result will + broadcast correctly against the original `x`. + + Returns + ------- + n : float or ndarray + Norm of the matrix or vector(s). + + References + ---------- + .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, + Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 + """ + return _mx_nd_np.linalg.norm(x, ord, axis, keepdims) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 4eddb0380c09..1f8aa92f9851 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -59,6 +59,7 @@ _NDARRAY_BASIC_INDEXING = 0 _NDARRAY_ADVANCED_INDEXING = 1 + # This function is copied from ndarray.py since pylint # keeps giving false alarm error of undefined-all-variable def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t): @@ -311,14 +312,14 @@ def _prepare_value_nd(self, value, bcast_shape, squeeze_axes=None): Note: mxnet.numpy.ndarray not support NDArray as assigned value. """ if isinstance(value, numeric_types): - value_nd = full(bcast_shape, value, ctx=self.context, dtype=self.dtype) + value_nd = full(bcast_shape, value, ctx=self.ctx, dtype=self.dtype) elif isinstance(value, self.__class__): - value_nd = value.as_in_context(self.context) + value_nd = value.as_in_ctx(self.ctx) if value_nd.dtype != self.dtype: value_nd = value_nd.astype(self.dtype) else: try: - value_nd = array(value, ctx=self.context, dtype=self.dtype) + value_nd = array(value, ctx=self.ctx, dtype=self.dtype) except: raise TypeError('mxnet.np.ndarray does not support assignment with non-array-like ' 'object {} of type {}'.format(value, type(value))) @@ -329,6 +330,19 @@ def _prepare_value_nd(self, value, bcast_shape, squeeze_axes=None): squeeze_axes = tuple([ax for ax in squeeze_axes if ax < len(value_nd.shape)]) value_nd = value_nd.squeeze(axis=tuple(squeeze_axes)) + # handle the cases like the following + # a = np.zeros((3, 3)), b = np.ones((1, 1, 1, 1, 3)), a[0] = b + # b cannot broadcast directly to a[0].shape unless its leading 1-size axes are trimmed + if value_nd.ndim > len(bcast_shape): + squeeze_axes = [] + for i in range(value_nd.ndim - len(bcast_shape)): + if value_nd.shape[i] == 1: + squeeze_axes.append(i) + else: + break + if squeeze_axes: + value_nd = value_nd.squeeze(squeeze_axes) + if value_nd.shape != bcast_shape: if value_nd.size == 0: value_nd = value_nd.reshape(bcast_shape) @@ -336,7 +350,6 @@ def _prepare_value_nd(self, value, bcast_shape, squeeze_axes=None): value_nd = value_nd.broadcast_to(bcast_shape) return value_nd - def __add__(self, other): """x.__add__(y) <=> x + y""" return add(self, other) @@ -727,14 +740,14 @@ def copyto(self, other): Examples -------- - >>> x = np.ones((2,3)) - >>> y = np.zeros((2,3), ctx=mx.gpu(0)) + >>> x = np.ones((2, 3)) + >>> y = np.zeros((2, 3), ctx=npx.gpu(0)) >>> z = x.copyto(y) >>> z is y True - >>> y.asnumpy() + >>> y array([[ 1., 1., 1.], - [ 1., 1., 1.]], dtype=float32) + [ 1., 1., 1.]]) """ if isinstance(other, ndarray): if other.handle is self.handle: @@ -756,6 +769,11 @@ def argmax(self, axis=None, out=None): # pylint: disable=arguments-differ return argmax(self, axis, out) def as_in_context(self, context): + """This function has been deprecated. Please refer to ``ndarray.as_in_ctx``.""" + warnings.warn('ndarray.context has been renamed to ndarray.ctx', DeprecationWarning) + return self.as_nd_ndarray().as_in_context(context).as_np_ndarray() + + def as_in_ctx(self, ctx): """Returns an array on the target device with the same value as this array. If the target context is the same as ``self.context``, then ``self`` is @@ -771,15 +789,58 @@ def as_in_context(self, context): ndarray The target array. """ - if self.context == context: + if self.ctx == ctx: return self - return self.copyto(context) + return self.copyto(ctx) + + @property + def ctx(self): + """Device context of the array. + + Examples + -------- + >>> x = np.array([1, 2, 3, 4]) + >>> x.ctx + cpu(0) + >>> type(x.ctx) + + >>> y = np.zeros((2, 3), npx.gpu(0)) + >>> y.ctx + gpu(0) + """ + dev_typeid = ctypes.c_int() + dev_id = ctypes.c_int() + check_call(_LIB.MXNDArrayGetContext( + self.handle, ctypes.byref(dev_typeid), ctypes.byref(dev_id))) + return Context(Context.devtype2str[dev_typeid.value], dev_id.value) + + @property + def context(self): + """This function has been deprecated. Please refer to ``ndarray.ctx``.""" + warnings.warn('ndarray.context has been renamed to ndarray.ctx', DeprecationWarning) + return self.as_nd_ndarray().context def copy(self, order='C'): # pylint: disable=arguments-differ + """Return a coyp of the array, keeping the same context. + + Parameters + ---------- + order : str + The memory layout of the copy. Currently, only c-contiguous memory + layout is supported. + + Examples + -------- + >>> x = np.ones((2, 3)) + >>> y = x.copy() + >>> y + array([[ 1., 1., 1.], + [ 1., 1., 1.]]) + """ if order != 'C': raise NotImplementedError('ndarray.copy only supports order=\'C\', while ' 'received {}'.format(str(order))) - return super(ndarray, self).copy().as_np_ndarray() + return self.copyto(self.ctx) def dot(self, b, out=None): """Dot product of two arrays. @@ -787,7 +848,7 @@ def dot(self, b, out=None): return _mx_np_op.dot(self, b, out=out) def reshape(self, *args, **kwargs): # pylint: disable=arguments-differ - """Returns an array containing the same data with a new shape. + """Returns a copy of the array with a new shape. Notes ----- @@ -854,7 +915,7 @@ def broadcast_axes(self, *args, **kwargs): def repeat(self, repeats, axis=None): # pylint: disable=arguments-differ """Repeat elements of an array.""" - raise NotImplementedError + return _mx_np_op.repeat(self, repeats=repeats, axis=axis) def pad(self, *args, **kwargs): """Convenience fluent method for :py:func:`pad`. @@ -1182,22 +1243,22 @@ def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=None): # pylint def cumsum(self, axis=None, dtype=None, out=None): """Return the cumulative sum of the elements along the given axis.""" - raise NotImplementedError + return _mx_np_op.cumsum(self, axis=axis, dtype=dtype, out=out) def tolist(self): return self.asnumpy().tolist() def max(self, axis=None, out=None, keepdims=False): # pylint: disable=arguments-differ """Return the maximum along a given axis.""" - raise NotImplementedError + return _mx_np_op.max(self, axis=axis, keepdims=keepdims, out=out) - def min(self, *args, **kwargs): + def min(self, axis=None, out=None, keepdims=False): # pylint: disable=arguments-differ """Convenience fluent method for :py:func:`min`. The arguments are the same as for :py:func:`min`, with this array as data. """ - raise NotImplementedError + return _mx_np_op.min(self, axis=axis, keepdims=keepdims, out=out) def norm(self, *args, **kwargs): """Convenience fluent method for :py:func:`norm`. @@ -1549,7 +1610,7 @@ def tostype(self, stype): @set_module('mxnet.numpy') -def empty(shape, dtype=float, order='C', ctx=None): +def empty(shape, dtype=_np.float32, order='C', ctx=None): """Return a new array of given shape and type, without initializing entries. Parameters @@ -1573,7 +1634,8 @@ def empty(shape, dtype=float, order='C', ctx=None): Array of uninitialized (arbitrary) data of the given shape, dtype, and order. """ if order != 'C': - raise NotImplementedError + raise NotImplementedError('`empty` only supports order equal to `C`, while received {}' + .format(str(order))) if ctx is None: ctx = current_context() if dtype is None: @@ -1609,7 +1671,7 @@ def array(object, dtype=None, ctx=None): if isinstance(object, ndarray): dtype = object.dtype if dtype is None else dtype else: - dtype = mx_real_t if dtype is None else dtype + dtype = _np.float32 if dtype is None else dtype if not isinstance(object, (ndarray, _np.ndarray)): try: object = _np.array(object, dtype=dtype) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 0841c0e4d2cc..077008aba119 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -223,11 +223,11 @@ def dot(self, b, out=None): return _mx_np_op.dot(self, b, out=out) def reshape(self, *args, **kwargs): # pylint: disable=arguments-differ - """Returns an array containing the same data with a new shape. + """Returns a copy of the array with a new shape. Notes ----- - Unlike the free function `numpy.reshape`, this method on `ndarray` allows + Unlike the free function `mxnet.numpy.reshape`, this method on `ndarray` allows the elements of the shape parameter to be passed in as separate arguments. For example, ``a.reshape(10, 11)`` is equivalent to ``a.reshape((10, 11))``. @@ -289,7 +289,7 @@ def broadcast_axes(self, *args, **kwargs): def repeat(self, repeats, axis=None): # pylint: disable=arguments-differ """Repeat elements of an array.""" - raise NotImplementedError + return _mx_np_op.repeat(self, repeats=repeats, axis=axis) def pad(self, *args, **kwargs): """Convenience fluent method for :py:func:`pad`. @@ -543,19 +543,15 @@ def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=None): # pylint def cumsum(self, axis=None, dtype=None, out=None): """Return the cumulative sum of the elements along the given axis.""" - raise NotImplementedError + return _mx_np_op.cumsum(self, axis=axis, dtype=dtype, out=out) def max(self, axis=None, out=None, keepdims=False): # pylint: disable=arguments-differ """Return the maximum along a given axis.""" - raise NotImplementedError - - def min(self, *args, **kwargs): - """Convenience fluent method for :py:func:`min`. + return _mx_np_op.max(self, axis=axis, keepdims=keepdims, out=out) - The arguments are the same as for :py:func:`min`, with - this array as data. - """ - raise NotImplementedError + def min(self, axis=None, out=None, keepdims=False): # pylint: disable=arguments-differ + """Return the minimum along a given axis.""" + return _mx_np_op.min(self, axis=axis, keepdims=keepdims, out=out) def norm(self, *args, **kwargs): """Convenience fluent method for :py:func:`norm`. diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py index 28cfd0f3806a..d1918ef8b903 100644 --- a/python/mxnet/symbol/numpy/linalg.py +++ b/python/mxnet/symbol/numpy/linalg.py @@ -18,5 +18,51 @@ """Namespace for operators used in Gluon dispatched by F=symbol.""" from __future__ import absolute_import +from . import _symbol +from . import _op as _mx_sym_np -__all__ = [] +__all__ = ['norm'] + + +def norm(x, ord=None, axis=None, keepdims=False): + r"""Matrix or vector norm. + + This function can only support Frobenius norm for now. + The Frobenius norm is given by [1]_: + + :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + + Parameters + ---------- + x : ndarray + Input array. + ord : {'fro'}, optional + Order of the norm. + axis : {int, 2-tuple of ints, None}, optional + If `axis` is an integer, it specifies the axis of `x` along which to + compute the vector norms. If `axis` is a 2-tuple, it specifies the + axes that hold 2-D matrices, and the matrix norms of these matrices + are computed. If `axis` is None, the norm of the whole ndarray is + returned. + + keepdims : bool, optional + If this is set to True, the axes which are normed over are left in the + result as dimensions with size one. With this option the result will + broadcast correctly against the original `x`. + + Returns + ------- + n : float or ndarray + Norm of the matrix or vector(s). + + References + ---------- + .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, + Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 + """ + if ord is not None and ord != 'fro': + raise ValueError('only support Frobenius norm for now, received ord={}'.format(str(ord))) + if isinstance(axis, tuple) and len(axis) > 2: + raise ValueError('Improper number of dimensions to norm') + # TODO(junwu): When ord = 'fro', axis = None, and x.ndim > 2, raise exception + return _symbol.sqrt(_mx_sym_np.sum(x * x, axis=axis, keepdims=keepdims)) diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py index e96f829a8580..62ea38fc0c13 100644 --- a/tests/python/unittest/test_numpy_gluon.py +++ b/tests/python/unittest/test_numpy_gluon.py @@ -19,9 +19,11 @@ from __future__ import absolute_import from __future__ import division +import numpy as _np import mxnet as mx from mxnet import gluon, autograd, np -from mxnet.test_utils import use_np +from mxnet.test_utils import use_np, assert_almost_equal +from common import with_seed def test_create_np_param(): @@ -108,6 +110,40 @@ def hybrid_forward(self, F, pred, label): trainer.step(1) +@with_seed() +@use_np +def test_np_loss_ndarray(): + # Ported from test_loss.test_loss_ndarray + output = np.array([1, 2, 3, 4]) + label = np.array([1, 3, 5, 7]) + weighting = np.array([0.5, 1, 0.5, 1]) + + loss = gluon.loss.L1Loss() + assert np.sum(loss(output, label)) == 6. + loss = gluon.loss.L1Loss(weight=0.5) + assert np.sum(loss(output, label)) == 3. + loss = gluon.loss.L1Loss() + assert np.sum(loss(output, label, weighting)) == 5. + + loss = gluon.loss.L2Loss() + assert np.sum(loss(output, label)) == 7. + loss = gluon.loss.L2Loss(weight=0.25) + assert np.sum(loss(output, label)) == 1.75 + loss = gluon.loss.L2Loss() + assert np.sum(loss(output, label, weighting)) == 6 + + output = np.array([[0, 2], [1, 4]]) + label = np.array([0, 1]) + weighting = np.array([[0.5], [1.0]]) + + loss = gluon.loss.SoftmaxCrossEntropyLoss() + L = loss(output, label).asnumpy() + assert_almost_equal(L, _np.array([2.12692809, 0.04858733]), use_broadcast=False) + + L = loss(output, label, weighting).asnumpy() + assert_almost_equal(L, _np.array([1.06346405, 0.04858733]), use_broadcast=False) + + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 883060466836..bffa7a00dccb 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -26,12 +26,45 @@ from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, retry, assert_exception, use_np from common import with_seed, TemporaryDirectory from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf +from mxnet.ndarray.ndarray import py_slice +from mxnet.base import integer_types import scipy.stats as ss @with_seed() @use_np -def test_array_creation(): +def test_np_empty(): + dtypes = [np.int8, np.int32, np.float16, np.float32, np.float64, None] + expected_dtypes = [np.int8, np.int32, np.float16, np.float32, np.float64, np.float32] + orders = ['C', 'F', 'A'] + shapes = [ + (), + 0, + (0,), + (0, 0), + 2, + (2,), + (3, 0), + (4, 5), + (1, 1, 1, 1), + ] + ctxes = [npx.current_context(), None] + for dtype, expected_dtype in zip(dtypes, expected_dtypes): + for shape in shapes: + for order in orders: + for ctx in ctxes: + if order == 'C': + ret = np.empty(shape, dtype, order, ctx) + assert ret.dtype == expected_dtype + assert ret.shape == shape if isinstance(shape, tuple) else (shape,) + assert ret.ctx == npx.current_context() + else: + assert_exception(np.empty, NotImplementedError, shape, dtype, order, ctx) + + +@with_seed() +@use_np +def test_np_array_creation(): dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, None] objects = [ [], @@ -54,7 +87,7 @@ def test_array_creation(): @with_seed() @use_np -def test_zeros(): +def test_np_zeros(): # test np.zeros in Gluon class TestZeros(HybridBlock): def __init__(self, shape, dtype=None): @@ -102,7 +135,7 @@ def check_zero_array_creation(shape, dtype): @with_seed() @use_np -def test_ones(): +def test_np_ones(): # test np.ones in Gluon class TestOnes(HybridBlock): def __init__(self, shape, dtype=None): @@ -149,7 +182,7 @@ def check_ones_array_creation(shape, dtype): @with_seed() -def test_ndarray_binary_element_wise_ops(): +def test_np_ndarray_binary_element_wise_ops(): np_op_map = { '+': _np.add, '*': _np.multiply, @@ -304,7 +337,7 @@ def check_binary_op_result(shape1, shape2, op, dtype=None): @with_seed() -def test_hybrid_block_multiple_outputs(): +def test_np_hybrid_block_multiple_outputs(): @use_np class TestAllNumpyOutputs(HybridBlock): def hybrid_forward(self, F, x, *args, **kwargs): @@ -338,7 +371,7 @@ def hybrid_forward(self, F, x, *args, **kwargs): @with_seed() @use_np -def test_grad_ndarray_type(): +def test_np_grad_ndarray_type(): data = np.array(2, dtype=_np.float32) data.attach_grad() assert type(data.grad) == np.ndarray @@ -379,7 +412,7 @@ def test_np_ndarray_copy(): def test_np_ndarray_indexing(): def np_int(index, int_type=np.int32): """ - Helper function for testing indexing that converts slices to slices of ints or None, and tuples to + Helper function for testing indexing that converts slices to slices of ints or None, and tuples to tuples of ints or None. """ def convert(num): @@ -401,7 +434,7 @@ def convert(num): else: assert False - # Copied from test_ndarray.py. Under construction. + # Copied from test_ndarray.py. Under construction. def test_getitem(np_array, index): np_index = index if type(index) == mx.nd.NDArray: # use of NDArray is prohibited @@ -439,6 +472,13 @@ def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None) assert same(np_array, mx_array.asnumpy()) + def _is_basic_index(index): + if isinstance(index, (integer_types, py_slice)): + return True + if isinstance(index, tuple) and all(isinstance(i, (integer_types, py_slice)) for i in index): + return True + return False + np_index = index # keep this native numpy type if isinstance(index, np.ndarray): np_index = index.asnumpy() @@ -467,6 +507,13 @@ def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None) assert_same(np_array, np_index, mx_array, index, np.array(np_value)) # test native numpy array with broadcast assert_same(np_array, np_index, mx_array, index, np_value) + + # test value shape are expanded to be longer than index array's shape + # this is currently only supported in basic indexing + if _is_basic_index(index): + expanded_value_shape = (1, 1, 1) + np_value.shape + assert_same(np_array, np_index, mx_array, index, np.array(np_value.reshape(expanded_value_shape))) + assert_same(np_array, np_index, mx_array, index, np_value.reshape(expanded_value_shape)) # test list with broadcast assert_same(np_array, np_index, mx_array, index, [_np.random.randint(low=-10000, high=0)] * indexed_array_shape[-1]) @@ -664,18 +711,18 @@ def test_setitem_autograd(np_array, index): test_setitem(np_array, index) test_getitem_autograd(np_array, index) test_setitem_autograd(np_array, index) - + # Test indexing to zero-size tensors index_list = [ - (slice(0, 0), slice(0, 0), 1, 2), - (slice(0, 0), slice(0, 0), slice(0, 0), slice(0, 0)), + (slice(0, 0), slice(0, 0), 1, 2), + (slice(0, 0), slice(0, 0), slice(0, 0), slice(0, 0)), ] for index in index_list: test_getitem(np_array, index) test_setitem(np_array, index) test_getitem_autograd(np_array, index) test_setitem_autograd(np_array, index) - + # test zero-size tensors get and setitem shapes_indices = [ ((0), [slice(None, None, None)]), diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 0f7355b7ef13..c5b0907fb7a8 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -288,7 +288,7 @@ def __init__(self, axis=None, keepdims=False): self._keepdims = keepdims def hybrid_forward(self, F, a, *args, **kwargs): - return F.np.max(a, axis=self._axis, keepdims=self._keepdims) + return a.max(axis=self._axis, keepdims=self._keepdims) class TestMin(HybridBlock): def __init__(self, axis=None, keepdims=False): @@ -297,7 +297,7 @@ def __init__(self, axis=None, keepdims=False): self._keepdims = keepdims def hybrid_forward(self, F, a, *args, **kwargs): - return F.np.min(a, axis=self._axis, keepdims=self._keepdims) + return a.min(axis=self._axis, keepdims=self._keepdims) def is_int(dtype): return 'int' == dtype @@ -326,12 +326,8 @@ def get_grad(axis, func_name): raise ValueError('axis should be int or None or ()') def _test_np_exception(func, shape, dim): - x = _np.random.uniform(-1.0, 1.0, shape) - x = mx.nd.array(x).as_np_ndarray() - if func == 'max': - out = mx.np.max(x) - else: - out = mx.np.min(x) + x = np.random.uniform(-1.0, 1.0, shape) + out = getattr(x, func)() assert out.ndim == dim, 'dimension mismatch, output.ndim={}, dim={}'.format(output.ndim, dim) in_data_dim = random.choice([2, 3, 4]) @@ -1620,7 +1616,7 @@ def __init__(self, axis=None, dtype=None): self._dtype = dtype def hybrid_forward(self, F, a): - return F.np.cumsum(a, axis=self._axis, dtype=self._dtype) + return a.cumsum(axis=self._axis, dtype=self._dtype) shapes = [(2, 3, 4), (2, 0, 3), ()] for hybridize in [True, False]: @@ -1790,6 +1786,73 @@ def hybrid_forward(self, F, x): assert mx_out.shape == np_out.shape +@with_seed() +@use_np +def test_np_repeat(): + config = [ + ((), 2, None), + ((), 0, None), + ((4, 2), 2, None), + ((4, 2), 2, 0), + ((4, 2), 2, 1), + ((4, 2), 2, -1), + ] + + class TestRepeat(HybridBlock): + def __init__(self, repeats, axis=None): + super(TestRepeat, self).__init__() + self._repeats = repeats + self._axis = axis + + def hybrid_forward(self, F, x): + return x.repeat(self._repeats, self._axis) + + for shape, repeats, axis in config: + data_np = _np.random.randint(low=0, high=1000, size=shape) + data_mx = np.array(data_np, dtype=data_np.dtype) + ret_np = data_np.repeat(repeats, axis) + ret_mx = data_mx.repeat(repeats, axis) + assert same(ret_mx.asnumpy(), ret_np) + + net = TestRepeat(repeats, axis) + for hybrid in [False, True]: + if hybrid: + net.hybridize() + ret_mx = net(data_mx) + assert same(ret_mx.asnumpy(), ret_np) + + +@with_seed() +@use_np +def test_np_linalg_norm(): + @use_np + class TestLinalgNorm(HybridBlock): + def __init__(self, ord=None, axis=None, keepdims=False): + super(TestLinalgNorm, self).__init__() + self._ord = ord + self._axis = axis + self._keepdims = keepdims + + def hybrid_forward(self, F, x): + return F.np.linalg.norm(x, ord=self._ord, axis=self._axis, keepdims=self._keepdims) + + a = np.arange(5 * 6 * 7 * 8).reshape((5, 6, 7, 8)) + ords = [None, 'fro'] + axes = [None, (0, 2), (1, 0), (1, 2)] + for ord in ords: + for axis in axes: + if ord == 'fro' and axis is None and a.ndim > 2: + continue + for keepdims in [False, True]: + for hybridize in [False, True]: + net = TestLinalgNorm(ord, axis, keepdims) + if hybridize: + net.hybridize() + mx_ret = net(a) + np_ret = _np.linalg.norm(a.asnumpy(), ord=ord, axis=axis, keepdims=keepdims) + assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-5, rtol=1e-4) + + if __name__ == '__main__': import nose nose.runmodule()