Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix sanity and flakiness of tests
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Sep 2, 2019
1 parent d5393a2 commit 19c89f8
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 6 deletions.
4 changes: 2 additions & 2 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2204,7 +2204,7 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable


@set_module('mxnet.ndarray.numpy')
def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: disable=too-many-arguments
"""
Compute the standard deviation along the specified axis.
Returns the standard deviation, a measure of the spread of a distribution,
Expand Down Expand Up @@ -2271,7 +2271,7 @@ def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):


@set_module('mxnet.ndarray.numpy')
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: disable=too-many-arguments
"""
Compute the variance along the specified axis.
Returns the variance of the array elements, a measure of the spread of a
Expand Down
90 changes: 88 additions & 2 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2570,15 +2570,101 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable


@set_module('mxnet.symbol.numpy')
def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: disable=too-many-arguments
"""
Compute the standard deviation along the specified axis.
Returns the standard deviation, a measure of the spread of a distribution,
of the array elements. The standard deviation is computed for the
flattened array by default, otherwise over the specified axis.
Parameters
----------
a : `_Symbol`
_Symbol containing numbers whose standard deviation is desired.
axis : None or int or tuple of ints, optional
Axis or axes along which the standard deviations are computed.
The default is to compute the standard deviation of the flattened array.
If this is a tuple of ints, computation is performed over multiple axes,
instead of a single axis or all the axes as before.
dtype : data-type, optional
Type to use in computing the standard deviation. For integer inputs, the default is float32;
for floating point inputs, it is the same as the input dtype.
out : _Symbol, optional
Dummy parameter to keep the consistency with the ndarray counterpart.
keepdims : bool, optional
If this is set to True, the axes which are reduced are left in the result
as dimensions with size one. With this option, the result will broadcast correctly
against the input array.
If the default value is passed, then keepdims will not be passed through to the mean
method of sub-classes of _Symbol, however any non-default value will be. If the sub-class
method does not implement keepdims any exceptions will be raised.
Returns
-------
m : _Symbol, see dtype parameter above
If out=None, returns a new array containing the standard deviation values,
otherwise a reference to the output array is returned.
Notes
-----
This function differs from the original `numpy.std
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.mean.html>`_ in
the following way(s):
- only _Symbol is accepted as valid input, python iterables or scalar is not supported
- default output data type for integer input is float32
"""
return _npi.std(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)


@set_module('mxnet.symbol.numpy')
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: disable=too-many-arguments
"""
Compute the variance along the specified axis.
Returns the variance of the array elements, a measure of the spread of a
distribution. The variance is computed for the flattened array by
default, otherwise over the specified axis.
Parameters
----------
a : `_Symbol`
_Symbol containing numbers whose variance is desired.
axis : None or int or tuple of ints, optional
Axis or axes along which the variance is computed.
The default is to compute the variance of the flattened array.
If this is a tuple of ints, computation is performed over multiple axes,
instead of a single axis or all the axes as before.
dtype : data-type, optional
Type to use in computing the variance. For integer inputs, the default is float32;
for floating point inputs, it is the same as the input dtype.
out : _Symbol, optional
Dummy parameter to keep the consistency with the ndarray counterpart.
keepdims : bool, optional
If this is set to True, the axes which are reduced are left in the result
as dimensions with size one. With this option, the result will broadcast correctly
against the input array.
If the default value is passed, then keepdims will not be passed through to the mean
method of sub-classes of _Symbol, however any non-default value will be. If the sub-class
method does not implement keepdims any exceptions will be raised.
Returns
-------
m : _Symbol, see dtype parameter above
If out=None, returns a new array containing the variance values,
otherwise a reference to the output array is returned.
Notes
-----
This function differs from the original `numpy.var
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.mean.html>`_ in
the following way(s):
- only _Symbol is accepted as valid input, python iterables or scalar is not supported
- default output data type for integer input is float32
"""
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)

Expand Down
3 changes: 1 addition & 2 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,6 @@ def is_int(dtype):
for axis in ([i for i in range(in_data_dim)] + [(), None]):
for itype in ['float16', 'float32', 'float64']:
for dtype in ['float16', 'float32', 'float64']:
print(itype, dtype)
if is_int(dtype) and not is_int(itype):
continue
# test gluon
Expand Down Expand Up @@ -494,7 +493,7 @@ def legalize_shape(shape):
for dtype in ['float16', 'float32', 'float64']:
if is_int(dtype) and not is_int(itype) or is_int(itype) and is_int(dtype):
continue
atol = 1e-4 if itype == 'float16' or dtype == 'float16' else 1e-5
atol = 3e-4 if itype == 'float16' or dtype == 'float16' else 1e-5
rtol = 1e-2 if itype == 'float16' or dtype == 'float16' else 1e-3
# test gluon
test_moment = TestMoment(name, axis=axis, dtype=dtype, keepdims=keepdims, ddof=ddof)
Expand Down

0 comments on commit 19c89f8

Please sign in to comment.