From 6756d775de045d1e5bd36250fd116ca659ec8f4d Mon Sep 17 00:00:00 2001 From: reminisce Date: Mon, 26 Aug 2019 16:22:00 -0700 Subject: [PATCH 1/8] Add maximum, minimum, swapaxes, argmax, clip in python --- python/mxnet/ndarray/numpy/_op.py | 187 ++++++++++++++++++++++++++- python/mxnet/numpy/multiarray.py | 187 ++++++++++++++++++++++++++- python/mxnet/symbol/numpy/_symbol.py | 129 +++++++++++++++++- 3 files changed, 493 insertions(+), 10 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index f0785a76818e..37f48f70c336 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -32,7 +32,8 @@ 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', - 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack'] + 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', + 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax'] @set_module('mxnet.ndarray.numpy') @@ -1960,3 +1961,187 @@ def get_list(arrays): arrays = get_list(arrays) return _npi.stack(*arrays, axis=axis, out=out) + + +@set_module('mxnet.ndarray.numpy') +def maximum(x1, x2, out=None): + """Returns element-wise maximum of the input arrays with broadcasting. + + Parameters + ---------- + x1, x2 : scalar or mxnet.numpy.ndarray + The arrays holding the elements to be compared. They must have the same shape, + or shapes that can be broadcast to a single shape. + + Returns + ------- + out : mxnet.numpy.ndarray or scalar + The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.""" + return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out) + + +@set_module('mxnet.ndarray.numpy') +def minimum(x1, x2, out=None): + """Returns element-wise minimum of the input arrays with broadcasting. + + Parameters + ---------- + x1, x2 : scalar or mxnet.numpy.ndarray + The arrays holding the elements to be compared. They must have the same shape, + or shapes that can be broadcast to a single shape. + + Returns + ------- + out : mxnet.numpy.ndarray or scalar + The minimum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.""" + return _ufunc_helper(x1, x2, _npi.minimum, _np.minimum, _npi.minimum_scalar, None, out) + + +@set_module('mxnet.ndarray.numpy') +def swapaxes(a, axis1, axis2): + """Interchange two axes of an array. + + Parameters + ---------- + a : ndarray + Input array. + axis1 : int + First axis. + axis2 : int + Second axis. + + Returns + ------- + a_swapped : ndarray + Swapped array. This is always a copy of the input array. + """ + return _npi.swapaxes(a, dim1=axis1, dim2=axis2) + + +@set_module('mxnet.ndarray.numpy') +def clip(a, a_min, a_max, out=None): + """clip(a, a_min, a_max, out=None) + + Clip (limit) the values in an array. + Given an interval, values outside the interval are clipped to + the interval edges. For example, if an interval of ``[0, 1]`` + is specified, values smaller than 0 become 0, and values larger + than 1 become 1. + + Parameters + ---------- + a : ndarray + Array containing elements to clip. + a_min : scalar or `None` + Minimum value. If `None`, clipping is not performed on lower + interval edge. Not more than one of `a_min` and `a_max` may be + `None`. + a_max : scalar or `None` + Maximum value. If `None`, clipping is not performed on upper + interval edge. Not more than one of `a_min` and `a_max` may be + `None`. + out : ndarray, optional + The results will be placed in this array. It may be the input + array for in-place clipping. `out` must be of the right shape + to hold the output. Its type is preserved. + + Returns + ------- + clipped_array : ndarray + An array with the elements of `a`, but where values + < `a_min` are replaced with `a_min`, and those > `a_max` + with `a_max`. + + Notes + ----- + array_like `a_min` and `a_max` are not supported. + + Examples + -------- + >>> a = np.arange(10) + >>> np.clip(a, 1, 8) + array([1., 1., 2., 3., 4., 5., 6., 7., 8., 8.], dtype=float32) + >>> a + array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32) + >>> np.clip(a, 3, 6, out=a) + array([3., 3., 3., 3., 4., 5., 6., 6., 6., 6.], dtype=float32) + """ + if a_min is None and a_max is None: + raise ValueError('array_clip: must set either max or min') + if a_min is None: + a_min = float('-inf') + if a_max is None: + a_max = float('inf') + return _npi.clip(a, a_min, a_max, out=out) + + +@set_module('mxnet.ndarray.numpy') +def argmax(a, axis=None, out=None): + r""" + argmax(a, axis=None, out=None) + + Returns the indices of the maximum values along an axis. + + Parameters + ---------- + a : ndarray + Input array. Only support ndarrays of dtype `float16`, `float32`, and `float64`. + axis : int, optional + By default, the index is into the flattened array, otherwise + along the specified axis. + out : ndarray or None, optional + A location into which the result is stored. + If provided, it must have the same shape and dtype as input ndarray. + If not provided or `None`, a freshly-allocated array is returned. + + Returns + ------- + index_array : ndarray of indices whose dtype is same as the input ndarray. + Array of indices into the array. It has the same shape as `a.shape` + with the dimension along `axis` removed. + + Notes + ----- + In case of multiple occurrences of the maximum values, the indices + corresponding to the first occurrence are returned. + + This function differs from the original `numpy.argmax + `_ in + the following aspects: + + - Input type does not support Python native iterables(list, tuple, ...). + - Output has dtype that is same as the input ndarray. + - ``out`` param: cannot perform auto broadcasting. ``out`` ndarray's shape must be the same as the expected output. + - ``out`` param: cannot perform auto type cast. ``out`` ndarray's dtype must be the same as the expected output. + - ``out`` param does not support scalar input case. + + Examples + -------- + >>> a = np.arange(6).reshape(2,3) + 10 + >>> a + array([[10., 11., 12.], + [13., 14., 15.]]) + >>> np.argmax(a) + array(5.) + >>> np.argmax(a, axis=0) + array([1., 1., 1.]) + >>> np.argmax(a, axis=1) + array([2., 2.]) + + >>> b = np.arange(6) + >>> b[1] = 5 + >>> b + array([0., 5., 2., 3., 4., 5.]) + >>> np.argmax(b) # Only the first occurrence is returned. + array(1.) + + Specify ``out`` ndarray: + + >>> a = np.arange(6).reshape(2,3) + 10 + >>> b = np.zeros((2,)) + >>> np.argmax(a, axis=1, out=b) + array([2., 2.]) + >>> b + array([2., 2.]) + """ + return _npi.argmax(a, axis=axis, keepdims=False, out=out) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 5e7129226e34..6272a688dc95 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -52,7 +52,7 @@ 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', - 'stack'] + 'stack', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -751,7 +751,7 @@ def asscalar(self): raise AttributeError('mxnet.numpy.ndarray object has no attribute asscalar') def argmax(self, axis=None, out=None): # pylint: disable=arguments-differ - raise NotImplementedError + return argmax(self, axis, out) def as_in_context(self, context): """Returns an array on the target device with the same value as this array. @@ -864,7 +864,7 @@ def swapaxes(self, axis1, axis2): # pylint: disable=arguments-differ """Return a copy of the array with axis1 and axis2 interchanged. Refer to `mxnet.numpy.swapaxes` for full documentation. """ - raise NotImplementedError + return swapaxes(self, axis1, axis2) def split(self, *args, **kwargs): """Convenience fluent method for :py:func:`split`. @@ -1045,12 +1045,12 @@ def argmin(self, *args, **kwargs): this array as data. """ raise NotImplementedError - + def clip(self, min=None, max=None, out=None): # pylint: disable=arguments-differ """Return an array whose values are limited to [min, max]. One of max or min must be given. """ - raise NotImplementedError + return clip(self, min, max, out=out) def abs(self, *args, **kwargs): """Convenience fluent method for :py:func:`abs`. @@ -3407,3 +3407,180 @@ def stack(arrays, axis=0, out=None): stacked : ndarray The stacked array has one more dimension than the input arrays.""" return _mx_nd_np.stack(arrays, axis=axis, out=out) + + +@set_module('mxnet.numpy') +def maximum(x1, x2, out=None): + """Returns element-wise maximum of the input arrays with broadcasting. + + Parameters + ---------- + x1, x2 : scalar or mxnet.numpy.ndarray + The arrays holding the elements to be compared. They must have the same shape, + or shapes that can be broadcast to a single shape. + + Returns + ------- + out : mxnet.numpy.ndarray or scalar + The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.""" + return _mx_nd_np.maximum(x1, x2, out=out) + + +@set_module('mxnet.numpy') +def minimum(x1, x2, out=None): + """Returns element-wise minimum of the input arrays with broadcasting. + + Parameters + ---------- + x1, x2 : scalar or mxnet.numpy.ndarray + The arrays holding the elements to be compared. They must have the same shape, + or shapes that can be broadcast to a single shape. + + Returns + ------- + out : mxnet.numpy.ndarray or scalar + The minimum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.""" + return _mx_nd_np.minimum(x1, x2, out=out) + + +@set_module('mxnet.numpy') +def swapaxes(a, axis1, axis2): + """Interchange two axes of an array. + + Parameters + ---------- + a : ndarray + Input array. + axis1 : int + First axis. + axis2 : int + Second axis. + + Returns + ------- + a_swapped : ndarray + Swapped array. This is always a copy of the input array. + """ + return _npi.swapaxes(a, dim1=axis1, dim2=axis2) + + +@set_module('mxnet.numpy') +def clip(a, a_min, a_max, out=None): + """clip(a, a_min, a_max, out=None) + + Clip (limit) the values in an array. + Given an interval, values outside the interval are clipped to + the interval edges. For example, if an interval of ``[0, 1]`` + is specified, values smaller than 0 become 0, and values larger + than 1 become 1. + + Parameters + ---------- + a : ndarray + Array containing elements to clip. + a_min : scalar or `None` + Minimum value. If `None`, clipping is not performed on lower + interval edge. Not more than one of `a_min` and `a_max` may be + `None`. + a_max : scalar or `None` + Maximum value. If `None`, clipping is not performed on upper + interval edge. Not more than one of `a_min` and `a_max` may be + `None`. + out : ndarray, optional + The results will be placed in this array. It may be the input + array for in-place clipping. `out` must be of the right shape + to hold the output. Its type is preserved. + + Returns + ------- + clipped_array : ndarray + An array with the elements of `a`, but where values + < `a_min` are replaced with `a_min`, and those > `a_max` + with `a_max`. + + Notes + ----- + array_like `a_min` and `a_max` are not supported. + + Examples + -------- + >>> a = np.arange(10) + >>> np.clip(a, 1, 8) + array([1., 1., 2., 3., 4., 5., 6., 7., 8., 8.], dtype=float32) + >>> a + array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32) + >>> np.clip(a, 3, 6, out=a) + array([3., 3., 3., 3., 4., 5., 6., 6., 6., 6.], dtype=float32) + """ + return _mx_nd_np.clip(a, a_min, a_max, out=out) + + +@set_module('mxnet.numpy') +def argmax(a, axis=None, out=None): + r""" + argmax(a, axis=None, out=None) + + Returns the indices of the maximum values along an axis. + + Parameters + ---------- + a : ndarray + Input array. Only support ndarrays of dtype `float16`, `float32`, and `float64`. + axis : int, optional + By default, the index is into the flattened array, otherwise + along the specified axis. + out : ndarray or None, optional + If provided, the result will be inserted into this array. It should + be of the appropriate shape and dtype. + + Returns + ------- + index_array : ndarray of indices whose dtype is same as the input ndarray. + Array of indices into the array. It has the same shape as `a.shape` + with the dimension along `axis` removed. + + Notes + ----- + In case of multiple occurrences of the maximum values, the indices + corresponding to the first occurrence are returned. + + This function differs from the original `numpy.argmax + `_ in + the following aspects: + + - Input type does not support Python native iterables(list, tuple, ...). + - Output has dtype that is same as the input ndarray. + - ``out`` param: cannot perform auto broadcasting. ``out`` ndarray's shape must be the same as the expected output. + - ``out`` param: cannot perform auto type cast. ``out`` ndarray's dtype must be the same as the expected output. + - ``out`` param does not support scalar input case. + + Examples + -------- + >>> a = np.arange(6).reshape(2,3) + 10 + >>> a + array([[10., 11., 12.], + [13., 14., 15.]]) + >>> np.argmax(a) + array(5.) + >>> np.argmax(a, axis=0) + array([1., 1., 1.]) + >>> np.argmax(a, axis=1) + array([2., 2.]) + + >>> b = np.arange(6) + >>> b[1] = 5 + >>> b + array([0., 5., 2., 3., 4., 5.]) + >>> np.argmax(b) # Only the first occurrence is returned. + array(1.) + + Specify ``out`` ndarray: + + >>> a = np.arange(6).reshape(2,3) + 10 + >>> b = np.zeros((2,)) + >>> np.argmax(a, axis=1, out=b) + array([2., 2.]) + >>> b + array([2., 2.]) + """ + return _mx_nd_np.argmax(a, axis, out) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index ff20cabdb748..8daf6b738662 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -34,7 +34,8 @@ 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', - 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack'] + 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', + 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax'] def _num_outputs(sym): @@ -248,7 +249,7 @@ def reshape(self, *args, **kwargs): # pylint: disable=arguments-differ return _mx_np_op.reshape(self, newshape=args, order=order) def argmax(self, axis=None, out=None): # pylint: disable=arguments-differ - raise NotImplementedError + return argmax(self, axis, out) def reshape_like(self, *args, **kwargs): """Convenience fluent method for :py:func:`reshape_like`. @@ -298,7 +299,7 @@ def swapaxes(self, axis1, axis2): # pylint: disable=arguments-differ """Return a copy of the array with axis1 and axis2 interchanged. Refer to `mxnet.numpy.swapaxes` for full documentation. """ - raise NotImplementedError + return swapaxes(self, axis1, axis2) def split(self, *args, **kwargs): """Convenience fluent method for :py:func:`split`. @@ -408,7 +409,7 @@ def clip(self, min=None, max=None, out=None): # pylint: disable=arguments-diffe """Return an array whose values are limited to [min, max]. One of max or min must be given. """ - raise NotImplementedError + return clip(self, min, max, out=out) def abs(self, *args, **kwargs): """Convenience fluent method for :py:func:`abs`. @@ -2383,4 +2384,124 @@ def get_list(arrays): return _npi.stack(*arrays, axis=axis, out=out) +@set_module('mxnet.symbol.numpy') +def maximum(x1, x2, out=None): + return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out) + + +@set_module('mxnet.symbol.numpy') +def minimum(x1, x2, out=None): + return _ufunc_helper(x1, x2, _npi.minimum, _np.minimum, _npi.minimum_scalar, None, out) + + +@set_module('mxnet.symbol.numpy') +def clip(a, a_min, a_max, out=None): + """clip(a, a_min, a_max, out=None) + + Clip (limit) the values in an array. + Given an interval, values outside the interval are clipped to + the interval edges. For example, if an interval of ``[0, 1]`` + is specified, values smaller than 0 become 0, and values larger + than 1 become 1. + + Parameters + ---------- + a : _Symbol + Array containing elements to clip. + a_min : scalar or `None` + Minimum value. If `None`, clipping is not performed on lower + interval edge. Not more than one of `a_min` and `a_max` may be + `None`. + a_max : scalar or `None` + Maximum value. If `None`, clipping is not performed on upper + interval edge. Not more than one of `a_min` and `a_max` may be + `None`. + out : _Symbol or `None` + The results will be placed in this array. It may be the input + array for in-place clipping. `out` must be of the right shape + to hold the output. Its type is preserved. + + Returns + ------- + clipped_array : _Symbol + An array with the elements of `a`, but where values + < `a_min` are replaced with `a_min`, and those > `a_max` + with `a_max`. + + Notes + ----- + array_like `a_min` and `a_max` are not supported. + """ + if a_min is None and a_max is None: + raise ValueError('array_clip: must set either max or min') + if a_min is None: + a_min = float('-inf') + if a_max is None: + a_max = float('inf') + return _npi.clip(a, a_min, a_max, out=out) + + +@set_module('mxnet.symbol.numpy') +def swapaxes(a, axis1, axis2): + """Interchange two axes of an array. + + Parameters + ---------- + a : _Symbol + Input array. + axis1 : int + First axis. + axis2 : int + Second axis. + + Returns + ------- + a_swapped : _Symbol + Swapped array symbol. + """ + return _npi.swapaxes(a, dim1=axis1, dim2=axis2) + + +@set_module('mxnet.symbol.numpy') +def argmax(a, axis=None, out=None): + r""" + argmax(a, axis=None, out=None) + + Returns the indices of the maximum values along an axis. + + Parameters + ---------- + a : _Symbol + Input array. Only support dtype `float16`, `float32`, and `float64`. + axis : int, optional + By default, the index is into the flattened array, otherwise + along the specified axis. + out : _Symbol or None, optional + Dummy parameter to keep the consistency with the ndarray counterpart. + + Returns + ------- + index_array : _Symbol of indices whose dtype is same as the input ndarray. + Array of indices into the array. It has the same shape as `a.shape` + with the dimension along `axis` removed. + + Notes + ----- + In case of multiple occurrences of the maximum values, the indices + corresponding to the first occurrence are returned. + + This function differs from the original `numpy.argmax + `_ in + the following aspects: + + - Input type does not support Python native iterables(list, tuple, ...). + - Output has dtype that is same as the input ndarray. + - ``out`` param: cannot perform auto broadcasting. ``out`` symbol's shape must be the same as the expected output. + - ``out`` param: cannot perform auto type cast. ``out`` symnbol's dtype must be the same as the expected output. + - ``out`` param does not support scalar input case. + + """ + return _npi.argmax(a, axis=axis, keepdims=False, out=out) + + _set_np_symbol_class(_Symbol) From f2dc04cb112ae359fdc5a4b0bf5af8211f57025d Mon Sep 17 00:00:00 2001 From: reminisce Date: Mon, 26 Aug 2019 23:16:18 -0700 Subject: [PATCH 2/8] Add backend --- python/mxnet/numpy/multiarray.py | 2 + python/mxnet/symbol/numpy/_symbol.py | 2 + .../numpy/np_broadcast_reduce_op_index.cc | 61 ++++++ .../numpy/np_elemwise_broadcast_op.cc | 74 -------- .../tensor/elemwise_binary_broadcast_op.h | 4 + .../elemwise_binary_broadcast_op_extended.cc | 2 + .../elemwise_binary_scalar_op_extended.cc | 2 + tests/python/unittest/test_numpy_op.py | 179 ++++++++++++++++++ 8 files changed, 252 insertions(+), 74 deletions(-) create mode 100644 src/operator/numpy/np_broadcast_reduce_op_index.cc diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 6272a688dc95..d6243ff058a8 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -751,6 +751,8 @@ def asscalar(self): raise AttributeError('mxnet.numpy.ndarray object has no attribute asscalar') def argmax(self, axis=None, out=None): # pylint: disable=arguments-differ + """Return indices of the maximum values along the given axis. + Refer to `mxnet.numpy.argmax` for full documentation.""" return argmax(self, axis, out) def as_in_context(self, context): diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 8daf6b738662..b1d36c57c048 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -249,6 +249,8 @@ def reshape(self, *args, **kwargs): # pylint: disable=arguments-differ return _mx_np_op.reshape(self, newshape=args, order=order) def argmax(self, axis=None, out=None): # pylint: disable=arguments-differ + """Return indices of the maximum values along the given axis. + Refer to `mxnet.numpy.argmax` for full documentation.""" return argmax(self, axis, out) def reshape_like(self, *args, **kwargs): diff --git a/src/operator/numpy/np_broadcast_reduce_op_index.cc b/src/operator/numpy/np_broadcast_reduce_op_index.cc new file mode 100644 index 000000000000..bd6915cc9b27 --- /dev/null +++ b/src/operator/numpy/np_broadcast_reduce_op_index.cc @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_broadcast_reduce_op_index.cc + * \brief CPU Implementation of broadcast and reduce functions based on index. + */ +#include "./np_broadcast_reduce_op.h" + +namespace mxnet { +namespace op { + +bool NumpyReduceAxisShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + if (!shape_is_known(in_attrs->at(0))) { + return false; + } + const ReduceAxisParam& param = nnvm::get(attrs.parsed); + dmlc::optional> axes; + if (param.axis.has_value()) { + mxnet::Tuple t({param.axis.value()}); + axes = dmlc::optional>(t); + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, + NumpyReduceAxesShapeImpl((*in_attrs)[0], axes, param.keepdims)); + return shape_is_known(out_attrs->at(0)); +} + +NNVM_REGISTER_OP(_npi_argmax) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", NumpyReduceAxisShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.add_argument("data", "NDArray-or-Symbol", "The input") +.set_attr("FCompute", SearchAxisCompute) +.set_attr("FGradient", MakeZeroGradNodes) +.add_arguments(ReduceAxisParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index c36423dff9fd..697657d84dd5 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -57,96 +57,22 @@ bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs, MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_add) -.describe(R"code(Add arguments element-wise with broadcasting if necessary. - -Example:: - - x = [[ 1., 1., 1.], - [ 1., 1., 1.]] - - y = [[ 0.], - [ 1.]] - - add(x, y) = [[ 1., 1., 1.], - [ 2., 2., 2.]] - -)code" ADD_FILELINE) .set_attr("FCompute", BinaryBroadcastCompute) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"}); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_subtract) -.describe(R"code(Subtract arguments element-wise with broadcasting if necessary. - -Example:: - - x = [[ 1., 1., 1.], - [ 1., 1., 1.]] - - y = [[ 0.], - [ 1.]] - - subtract(x, y) = [[ 1., 1., 1.], - [ 0., 0., 0.]] - -)code" ADD_FILELINE) .set_attr("FCompute", BinaryBroadcastCompute) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"}); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_multiply) -.describe(R"code(Multiply arguments with broadcasting if necessary. - -Example:: - - x = [[ 1., 1., 1.], - [ 1., 1., 1.]] - - y = [[ 0.], - [ 1.]] - - multiply(x, y) = [[ 0., 0., 0.], - [ 1., 1., 1.]] - -)code" ADD_FILELINE) .set_attr("FCompute", BinaryBroadcastCompute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"}); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod) -.describe(R"code(Return element-wise remainder of division. -It is equivalent to the Python modulus operator``x1 % x2`` and has the same sign as the divisor x2. - -Example:: - - x = [[ 8., 8., 8.], - [ 8., 8., 8.]] - - y = [[ 2.], - [ 3.]] - - mod(x, y) = [[ 0., 0., 0.], - [ 2., 2., 2.]] - -)code" ADD_FILELINE) .set_attr("FCompute", BinaryBroadcastCompute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mod"}); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_power) -.describe(R"code(First array elements raised to powers from second array, element-wise. - -Raise each base in x1 to the positionally-corresponding power in x2. x1 and x2 must be -broadcastable to the same shape. - -Example:: - - x = [[ 1., 1., 1.], - [ 1., 1., 1.]] - - y = [[ 0.], - [ 1.]] - - power(x, y) = [[ 2., 2., 2.], - [ 4., 4., 4.]] - -)code" ADD_FILELINE) .set_attr("FCompute", BinaryBroadcastCompute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"}); diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index 8a81bbc1c475..29c476d06d1c 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -599,6 +599,10 @@ void BinaryBroadcastBackwardUseIn(const nnvm::NodeAttrs& attrs, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { + // skip kernel launch for zero-size tensors + if (inputs[0].shape_.Size() == 0U) { + return; + } mxnet::TShape new_lshape, new_rshape, new_oshape; const bool need_bc = BinaryBroadcastShapeCompact(outputs[0].shape_, outputs[1].shape_, inputs[0].shape_, diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_extended.cc b/src/operator/tensor/elemwise_binary_broadcast_op_extended.cc index 842007eb497a..9e52b3197dcb 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op_extended.cc +++ b/src/operator/tensor/elemwise_binary_broadcast_op_extended.cc @@ -62,6 +62,7 @@ NNVM_REGISTER_OP(_backward_broadcast_power) mshadow_op::power_rgrad>); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_maximum) +.add_alias("_npi_maximum") .describe(R"code(Returns element-wise maximum of the input arrays with broadcasting. This function compares two input arrays and returns a new array having the element-wise maxima. @@ -97,6 +98,7 @@ NNVM_REGISTER_OP(_backward_broadcast_maximum) mshadow_op::lt>); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_minimum) +.add_alias("_npi_minimum") .describe(R"code(Returns element-wise minimum of the input arrays with broadcasting. This function compares two input arrays and returns a new array having the element-wise minima. diff --git a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc index 3a687c2aa062..f8f3e6a35201 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc +++ b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc @@ -29,6 +29,7 @@ namespace mxnet { namespace op { MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_maximum_scalar) +.add_alias("_npi_maximum_scalar") .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_maximum_scalar"}) .add_alias("_MaximumScalar"); @@ -39,6 +40,7 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_maximum_scalar) .set_attr("FCompute", BinaryScalarOp::Backward); MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_minimum_scalar) +.add_alias("_npi_minimum_scalar") .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_minimum_scalar"}) .add_alias("_MinimumScalar"); diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 0a1a4fb2b9b1..8d78c78ad1d7 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1126,6 +1126,185 @@ def test_np_randint(): verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs, nrepeat=100) +def test_np_minimum_maximum(): + def check_symbol_output_type(op_name): + x1, x2 = mx.sym.var('x1').as_np_ndarray(), mx.sym.var('x2').as_np_ndarray() + ret = getattr(mx.sym.np, op_name)(x1, x2) + assert type(ret) == mx.sym.np._Symbol + + def check_comp_op(op_name, x1, x2): + mx_out = getattr(np, op_name)(x1, x2) + if isinstance(x1, np.ndarray) or isinstance(x2, np.ndarray): + assert type(mx_out) == np.ndarray + np_out = getattr(_np, op_name)(x1.asnumpy() if isinstance(x1, np.ndarray) else x1, + x2.asnumpy() if isinstance(x2, np.ndarray) else x2) + assert same(mx_out.asnumpy() if isinstance(mx_out, np.ndarray) else mx_out, np_out) + + op_names = ['minimum', 'maximum'] + for op_name in op_names: + check_symbol_output_type(op_name) + check_comp_op(op_name, np.random.uniform(size=(2, 1)), np.random.uniform(size=(5, 1, 4))) + check_comp_op(op_name, np.random.uniform(size=(2, 0)), np.random.uniform(size=(5, 1, 1))) + check_comp_op(op_name, np.random.uniform(), np.random.uniform(size=(5, 1, 4))) + check_comp_op(op_name, _np.random.uniform(), np.random.uniform(size=(2, 3))) + check_comp_op(op_name, np.random.uniform(size=(2, 3)), _np.random.uniform()) + + +@with_seed() +@use_np +def test_np_swapaxes(): + config = [((0, 1, 2), 0, 1), + ((0, 1, 2), -1, -2), + ((4, 5, 6, 7), 2, 3), + ((4, 5, 6, 7), -2, -3)] + + class TestSwapaxes(HybridBlock): + def __init__(self, axis1, axis2): + super(TestSwapaxes, self).__init__() + self._axis1 = axis1 + self._axis2 = axis2 + + def hybrid_forward(self, F, x): + return F.np.swapaxes(x, self._axis1, self._axis2) + + for shape, axis1, axis2 in config: + data_np = _np.random.uniform(size=shape) + data_mx = np.array(data_np, dtype=data_np.dtype) + ret_np = _np.swapaxes(data_np, axis1=axis1, axis2=axis2) + ret_mx = np.swapaxes(data_mx, axis1=axis1, axis2=axis2) + assert same(ret_mx.asnumpy(), ret_np) + + net = TestSwapaxes(axis1, axis2) + for hybrid in [False, True]: + if hybrid: + net.hybridize() + ret_mx = net(data_mx) + assert same(ret_mx.asnumpy(), ret_np) + + +@with_seed() +@use_np +def test_np_argmax(): + workloads = [ + ((), 0, False), + ((), -1, False), + ((), 1, True), + ((5, 3), None, False), + ((5, 3), -1, False), + ((5, 3), 1, False), + ((5, 3), 3, True), + ((5, 0, 3), 0, False), + ((5, 0, 3), -1, False), + ((5, 0, 3), None, True), + ((5, 0, 3), 1, True), + ] + dtypes = ['float16', 'float32', 'float64'] + + class TestArgMax(HybridBlock): + def __init__(self, axis=None): + super(TestArgMax, self).__init__() + self._axis = axis + + def hybrid_forward(self, F, x): + return F.np.argmax(x, self._axis) + + for shape, axis, throw_exception in workloads: + for dtype in dtypes: + a = np.random.uniform(size=shape, dtype=dtype) + if throw_exception: + # Cannot use assert_exception because sometimes the main thread + # proceeds to `assert False` before the exception is thrown + # in the worker thread. Have to use mx.nd.waitall() here + # to block the main thread. + try: + np.argmax(a, axis) + mx.nd.waitall() + assert False + except mx.MXNetError: + pass + else: + mx_ret = np.argmax(a, axis=axis) + np_ret = _np.argmax(a.asnumpy(), axis=axis) + assert same(mx_ret.asnumpy(), np_ret) + + for hybridize in [False, True]: + net = TestArgMax(axis) + if hybridize: + net.hybridize() + if throw_exception: + try: + net(a) + mx.nd.waitall() + assert False + except mx.MXNetError: + pass + else: + mx_ret = net(a) + assert same(mx_ret.asnumpy(), np_ret) + + +@with_seed() +@use_np +def test_np_clip(): + workloads = [ + ((), None, None, True), + ((), None, 1, False), + ((), -1, 1, False), + ((), -1, None, False), + ((5, 3), None, 0.1, False), + ((5, 3), -0.1, None, False), + ((5, 3), -0.1, 0.1, False), + ((5, 3), 0, 0, False), + ((5, 0, 3), 0, None, False), + ((5, 0, 3), None, -1, False), + ((5, 0, 3), -1, 0, False), + ] + dtypes = ['float32', 'float64'] + + class TestClip(HybridBlock): + def __init__(self, a_min=None, a_max=None): + super(TestClip, self).__init__() + self._a_min = a_min + self._a_max = a_max + + def hybrid_forward(self, F, x): + return x.clip(self._a_min, self._a_max) + + for shape, a_min, a_max, throw_exception in workloads: + for dtype in dtypes: + a = np.random.uniform(size=shape, dtype=dtype) + if throw_exception: + # Cannot use assert_exception because sometimes the main thread + # proceeds to `assert False` before the exception is thrown + # in the worker thread. Have to use mx.nd.waitall() here + # to block the main thread. + try: + a.clip(min=a_min, max=a_max) + mx.nd.waitall() + assert False + except: + pass + else: + mx_ret = a.clip(min=a_min, max=a_max) + np_ret = a.asnumpy().clip(min=a_min, max=a_max) + assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-4, rtol=1e-3, use_broadcast=False) + + for hybridize in [False, True]: + net = TestClip(a_min, a_max) + if hybridize: + net.hybridize() + if throw_exception: + try: + net(a) + mx.nd.waitall() + assert False + except: + pass + else: + mx_ret = net(a) + assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-4, rtol=1e-3, use_broadcast=False) + + if __name__ == '__main__': import nose nose.runmodule() From 1aacbe5a457840c5653a948a4d2a02ab7c2db991 Mon Sep 17 00:00:00 2001 From: reminisce Date: Mon, 26 Aug 2019 23:19:12 -0700 Subject: [PATCH 3/8] Fix pylint --- python/mxnet/numpy/multiarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index d6243ff058a8..749d560d58d7 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -1047,7 +1047,7 @@ def argmin(self, *args, **kwargs): this array as data. """ raise NotImplementedError - + def clip(self, min=None, max=None, out=None): # pylint: disable=arguments-differ """Return an array whose values are limited to [min, max]. One of max or min must be given. From 1e1e33d48d269b0b2ea44b6f4763c4539d89a644 Mon Sep 17 00:00:00 2001 From: reminisce Date: Mon, 26 Aug 2019 23:30:08 -0700 Subject: [PATCH 4/8] Add unit test decorators back --- tests/python/unittest/test_numpy_op.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 8d78c78ad1d7..5a38ad2a5e2d 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1126,6 +1126,8 @@ def test_np_randint(): verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs, nrepeat=100) +@with_seed() +@use_np def test_np_minimum_maximum(): def check_symbol_output_type(op_name): x1, x2 = mx.sym.var('x1').as_np_ndarray(), mx.sym.var('x2').as_np_ndarray() From 656dc57fdb403ee1af78277c27485e3c01231f43 Mon Sep 17 00:00:00 2001 From: reminisce Date: Tue, 27 Aug 2019 14:02:34 -0700 Subject: [PATCH 5/8] Fix gpu compile --- src/operator/numpy/np_elemwise_broadcast_op.cu | 12 ------------ .../tensor/elemwise_binary_scalar_op_extended.cc | 8 ++++---- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu index c858b3a4987a..ac8def2af2c2 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op.cu @@ -42,12 +42,6 @@ NNVM_REGISTER_OP(_npi_mod) NNVM_REGISTER_OP(_npi_power) .set_attr("FCompute", BinaryBroadcastCompute); -NNVM_REGISTER_OP(_npi_maximum) -.set_attr("FCompute", BinaryBroadcastCompute); - -NNVM_REGISTER_OP(_npi_minimum) -.set_attr("FCompute", BinaryBroadcastCompute); - NNVM_REGISTER_OP(_npi_add_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); @@ -72,11 +66,5 @@ NNVM_REGISTER_OP(_npi_power_scalar) NNVM_REGISTER_OP(_npi_rpower_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); -NNVM_REGISTER_OP(_npi_maximum_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); - -NNVM_REGISTER_OP(_npi_minimum_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); - } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc index f8f3e6a35201..ba6ebccc5f01 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc +++ b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc @@ -29,10 +29,10 @@ namespace mxnet { namespace op { MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_maximum_scalar) -.add_alias("_npi_maximum_scalar") .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_maximum_scalar"}) -.add_alias("_MaximumScalar"); +.add_alias("_MaximumScalar") +.add_alias("_npi_maximum_scalar"); MXNET_OPERATOR_REGISTER_BINARY(_backward_maximum_scalar) .add_argument("scalar", "float", "scalar value") @@ -40,10 +40,10 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_maximum_scalar) .set_attr("FCompute", BinaryScalarOp::Backward); MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_minimum_scalar) -.add_alias("_npi_minimum_scalar") .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_minimum_scalar"}) -.add_alias("_MinimumScalar"); +.add_alias("_MinimumScalar") +.add_alias("_npi_minimum_scalar"); MXNET_OPERATOR_REGISTER_BINARY(_backward_minimum_scalar) .add_argument("scalar", "float", "scalar value") From 941aa641b48d53dfa3815d6ef0a65eb9e5266883 Mon Sep 17 00:00:00 2001 From: reminisce Date: Tue, 27 Aug 2019 22:54:06 -0700 Subject: [PATCH 6/8] Add np.random.normal and npx.seed --- python/mxnet/ndarray/numpy/random.py | 52 ++++++++++++- python/mxnet/numpy/random.py | 38 +++++++++- python/mxnet/numpy_extension/__init__.py | 1 + python/mxnet/numpy_extension/random.py | 74 +++++++++++++++++++ python/mxnet/symbol/numpy/random.py | 52 ++++++++++++- .../numpy/np_broadcast_reduce_op_index.cu | 34 +++++++++ tests/python/unittest/test_numpy_op.py | 47 ++++++++++++ 7 files changed, 291 insertions(+), 7 deletions(-) create mode 100644 python/mxnet/numpy_extension/random.py create mode 100644 src/operator/numpy/np_broadcast_reduce_op_index.cu diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 71707d41c8e8..d892ccdaca73 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -19,9 +19,10 @@ from __future__ import absolute_import from ...context import current_context from . import _internal as _npi +from ...base import numeric_types -__all__ = ['randint', 'uniform'] +__all__ = ['randint', 'uniform', 'normal'] def randint(low, high=None, size=None, dtype=None, **kwargs): @@ -141,5 +142,50 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): return _npi.uniform(low=low, high=high, size=size, ctx=ctx, dtype=dtype, out=out) - raise ValueError( - "Distribution parameters must be either mxnet.numpy.ndarray or numbers") + +def normal(loc=0.0, scale=1.0, size=None, **kwargs): + """Draw random samples from a normal (Gaussian) distribution. + + Samples are distributed according to a normal distribution parametrized + by *loc* (mean) and *scale* (standard deviation). + + + Parameters + ---------- + loc : float, optional + Mean (centre) of the distribution. + scale : float, optional + Standard deviation (spread or "width") of the distribution. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., `(m, n, k)`, then `m * n * k` + samples are drawn. If size is `None` (default), a scalar tensor containing + a single value is returned if loc and scale are both scalars. + dtype : {'float16', 'float32', 'float64'}, optional + Data type of output samples. Default is 'float32' + ctx : Context, optional + Device context of output. Default is current context. + out : ``ndarray``, optional + Store output to an existing ``ndarray``. + + Returns + ------- + out : ndarray + Drawn samples from the parameterized normal distribution. + + Notes + ----- + This function currently does not support ``loc`` and ``scale`` as ndarrays. + """ + dtype = kwargs.pop('dtype', None) + if dtype is None: + dtype = 'float32' + ctx = kwargs.pop('ctx', None) + if ctx is None: + ctx = current_context() + out = kwargs.pop('out', None) + if size is None and out is None: + size = () + if (not isinstance(loc, numeric_types)) or (not isinstance(scale, numeric_types)): + raise NotImplementedError('np.random.normal only supports loc and scale of ' + 'numeric types for now') + return _npi.random_normal(loc, scale, shape=size, dtype=dtype, ctx=ctx, out=out, **kwargs) diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index f0fd43eb0e70..dc6107476f81 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -21,7 +21,7 @@ from ..ndarray import numpy as _mx_nd_np -__all__ = ["randint", "uniform"] +__all__ = ["randint", "uniform", "normal"] def randint(low, high=None, size=None, dtype=None, **kwargs): @@ -108,3 +108,39 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): Drawn samples from the parameterized uniform distribution. """ return _mx_nd_np.random.uniform(low, high, size=size, ctx=ctx, dtype=dtype, out=out) + + +def normal(loc=0.0, scale=1.0, size=None, **kwargs): + """Draw random samples from a normal (Gaussian) distribution. + + Samples are distributed according to a normal distribution parametrized + by *loc* (mean) and *scale* (standard deviation). + + + Parameters + ---------- + loc : float, optional + Mean (centre) of the distribution. + scale : float, optional + Standard deviation (spread or "width") of the distribution. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., `(m, n, k)`, then `m * n * k` + samples are drawn. If size is `None` (default), a scalar tensor containing + a single value is returned if loc and scale are both scalars. + dtype : {'float16', 'float32', 'float64'}, optional + Data type of output samples. Default is 'float32' + ctx : Context, optional + Device context of output. Default is current context. + out : ``ndarray``, optional + Store output to an existing ``ndarray``. + + Returns + ------- + out : ndarray + Drawn samples from the parameterized normal distribution. + + Notes + ----- + This function currently does not support ``loc`` and ``scale`` as ndarrays. + """ + return _mx_nd_np.random.normal(loc, scale, size, **kwargs) diff --git a/python/mxnet/numpy_extension/__init__.py b/python/mxnet/numpy_extension/__init__.py index d71d65f08de2..4c26f59b980b 100644 --- a/python/mxnet/numpy_extension/__init__.py +++ b/python/mxnet/numpy_extension/__init__.py @@ -28,5 +28,6 @@ from ..util import is_np_shape, is_np_array, set_np, reset_np from ..ndarray import waitall from .utils import * # pylint: disable=wildcard-import +from .random import * # pylint: disable=wildcard-import __all__ = [] diff --git a/python/mxnet/numpy_extension/random.py b/python/mxnet/numpy_extension/random.py new file mode 100644 index 000000000000..bfe2270358bb --- /dev/null +++ b/python/mxnet/numpy_extension/random.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Namespace for ops used in imperative programming.""" + +from __future__ import absolute_import +from .. import random as _mx_rand + + +__all__ = ['seed'] + + +def seed(seed, ctx='all'): # pylint: disable=redefined-outer-name + """Seeds the random number generators in MXNet. + + This affects the behavior of modules in MXNet that uses random number generators, + like the dropout operator and `ndarray`'s random sampling operators. + + Parameters + ---------- + seed : int + The random number seed. + + ctx : Context + The device context of the generator. The default is "all" which means seeding random + number generators of all devices. + + Notes + ----- + Random number generators in MXNet are device specific. + `mx.random.seed(seed_state)` sets the state of each generator using `seed_state` and the + device id. Therefore, random numbers generated from different devices can be different + even if they are seeded using the same seed. + + To produce identical random number sequences independent of the device id, + set optional `ctx` argument. This produces the same sequence of random numbers independent + of the device id, but the sequence can be different on different kind of devices as MXNet's + random number generators for CPU and GPU use different algorithms. + + Example + ------- + >>> from mxnet import np, npx + >>> npx.set_np() + >>> npx.random.seed(0) + >>> np.random.uniform() + array(0.5488135) + >>> npx.random.seed(128) + >>> np.random.uniform() + array(0.03812965) + >>> npx.random.seed(128) + >>> np.random.uniform() + array(0.03812965) + >>> npx.random.seed(128) + >>> np.random.uniform(ctx=npx.gpu(0)) + array(0.9894903, ctx=gpu(0)) + >>> npx.random.seed(128) + >>> np.random.uniform(ctx=npx.gpu(0)) + array(0.9894903, ctx=gpu(0)) + """ + _mx_rand.seed(seed_state=seed, ctx=ctx) diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index 86d0ba3095e1..c5b8e1dc4906 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -20,9 +20,10 @@ from __future__ import absolute_import from ...context import current_context from . import _internal as _npi +from ...base import numeric_types -__all__ = ['randint', 'uniform'] +__all__ = ['randint', 'uniform', 'normal'] def randint(low, high=None, size=None, dtype=None, **kwargs): @@ -142,5 +143,50 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): return _npi.uniform(low=low, high=high, size=size, ctx=ctx, dtype=dtype, out=out) - raise ValueError( - "Distribution parameters must be either mxnet.numpy.ndarray or numbers") + +def normal(loc=0.0, scale=1.0, size=None, **kwargs): + """Draw random samples from a normal (Gaussian) distribution. + + Samples are distributed according to a normal distribution parametrized + by *loc* (mean) and *scale* (standard deviation). + + + Parameters + ---------- + loc : float, optional + Mean (centre) of the distribution. + scale : float, optional + Standard deviation (spread or "width") of the distribution. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., `(m, n, k)`, then `m * n * k` + samples are drawn. If size is `None` (default), a scalar tensor containing + a single value is returned if loc and scale are both scalars. + dtype : {'float16', 'float32', 'float64'}, optional + Data type of output samples. Default is 'float32' + ctx : Context, optional + Device context of output. Default is current context. + out : ``ndarray``, optional + Store output to an existing ``ndarray``. + + Returns + ------- + out : _Symbol (symbol representing `mxnet.numpy.ndarray` in computational graphs) + Drawn samples from the parameterized normal distribution. + + Notes + ----- + This function currently does not support ``loc`` and ``scale`` as `_Symbol`s. + """ + dtype = kwargs.pop('dtype', None) + if dtype is None: + dtype = 'float32' + ctx = kwargs.pop('ctx', None) + if ctx is None: + ctx = current_context() + out = kwargs.pop('out', None) + if size is None and out is None: + size = () + if (not isinstance(loc, numeric_types)) or (not isinstance(scale, numeric_types)): + raise NotImplementedError('np.random.normal only supports loc and scale of ' + 'numeric types for now') + return _npi.random_normal(loc, scale, shape=size, dtype=dtype, ctx=ctx, out=out, **kwargs) diff --git a/src/operator/numpy/np_broadcast_reduce_op_index.cu b/src/operator/numpy/np_broadcast_reduce_op_index.cu new file mode 100644 index 000000000000..a07baa9c070c --- /dev/null +++ b/src/operator/numpy/np_broadcast_reduce_op_index.cu @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_broadcast_reduce_op_index.cu + * \brief GPU Implementation of broadcast and reduce functions based on index. + */ +#include "./np_broadcast_reduce_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_argmax) +.set_attr("FCompute", SearchAxisCompute); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 5a38ad2a5e2d..cedee6487f7d 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1307,6 +1307,53 @@ def hybrid_forward(self, F, x): assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-4, rtol=1e-3, use_broadcast=False) +@with_seed() +@use_np +def test_np_random(): + shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None] + dtypes = ['float16', 'float32', 'float64'] + op_names = ['uniform', 'normal'] + op_names = ['normal'] + for shape in shapes: + for dtype in dtypes: + for op_name in op_names: + print('-------------------------------') + print(op_name) + print(shape) + print(dtype) + op = getattr(np.random, op_name, None) + assert op is not None + out = op(size=shape, dtype=dtype) + expected_shape = shape + if not isinstance(shape, tuple): + expected_shape = () if shape is None else (shape,) + assert out.shape == expected_shape + + class TestRandom(HybridBlock): + def __init__(self, shape, op_name): + super(TestRandom, self).__init__() + self._shape = shape + self._op_name = op_name + + def hybrid_forward(self, F, x): + op = getattr(F.np.random, self._op_name, None) + assert op is not None + return x + op(size=shape) + + x = np.ones(()) + for op_name in op_names: + for shape in shapes: + for hybridize in [False, True]: + net = TestRandom(shape, op_name) + if hybridize: + net.hybridize() + out = net(x) + expected_shape = shape + if not isinstance(shape, tuple): + expected_shape = () if shape is None else (shape,) + assert out.shape == expected_shape + + if __name__ == '__main__': import nose nose.runmodule() From 0ef20e176fc28cceeb4031517aa45726c862e6ac Mon Sep 17 00:00:00 2001 From: reminisce Date: Wed, 28 Aug 2019 14:18:27 -0700 Subject: [PATCH 7/8] Expose seed through npx.random --- python/mxnet/numpy_extension/__init__.py | 2 +- python/mxnet/numpy_extension/random.py | 2 +- tests/python/unittest/test_numpy_op.py | 11 +++++++++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/python/mxnet/numpy_extension/__init__.py b/python/mxnet/numpy_extension/__init__.py index 4c26f59b980b..5a19d3dc4e2f 100644 --- a/python/mxnet/numpy_extension/__init__.py +++ b/python/mxnet/numpy_extension/__init__.py @@ -28,6 +28,6 @@ from ..util import is_np_shape, is_np_array, set_np, reset_np from ..ndarray import waitall from .utils import * # pylint: disable=wildcard-import -from .random import * # pylint: disable=wildcard-import +from . import random # pylint: disable=wildcard-import __all__ = [] diff --git a/python/mxnet/numpy_extension/random.py b/python/mxnet/numpy_extension/random.py index bfe2270358bb..5aa58a0cc69d 100644 --- a/python/mxnet/numpy_extension/random.py +++ b/python/mxnet/numpy_extension/random.py @@ -42,7 +42,7 @@ def seed(seed, ctx='all'): # pylint: disable=redefined-outer-name Notes ----- Random number generators in MXNet are device specific. - `mx.random.seed(seed_state)` sets the state of each generator using `seed_state` and the + `npx.random.seed(seed)` sets the state of each generator using `seed` and the device id. Therefore, random numbers generated from different devices can be different even if they are seeded using the same seed. diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index cedee6487f7d..bb6c86d3691f 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1354,6 +1354,17 @@ def hybrid_forward(self, F, x): assert out.shape == expected_shape +@with_seed() +@use_np +def test_random_seed(): + for seed in [234, 594, 7240, 20394]: + ret = [] + for _ in range(2): + npx.random.seed(seed=seed) + ret.append(np.random.uniform(size=(2, 3))) + assert_almost_equal(ret[0].asnumpy(), ret[1].asnumpy(), use_broadcast=False) + + if __name__ == '__main__': import nose nose.runmodule() From ea223834c4233dbcee29b14f737176f96df886d7 Mon Sep 17 00:00:00 2001 From: reminisce Date: Wed, 28 Aug 2019 14:21:13 -0700 Subject: [PATCH 8/8] Add rtol atol in seed testing --- tests/python/unittest/test_numpy_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index bb6c86d3691f..c468a6b7a0fa 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1362,7 +1362,7 @@ def test_random_seed(): for _ in range(2): npx.random.seed(seed=seed) ret.append(np.random.uniform(size=(2, 3))) - assert_almost_equal(ret[0].asnumpy(), ret[1].asnumpy(), use_broadcast=False) + assert_almost_equal(ret[0].asnumpy(), ret[1].asnumpy(), rtol=1e-4, atol=1e-5, use_broadcast=False) if __name__ == '__main__':