Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add fluent methods mean, std, var for ndarray (#16077)
Browse files Browse the repository at this point in the history
* Add fluent methods mean, std, var for ndarray

* Fix sanity
  • Loading branch information
reminisce authored and haojin2 committed Sep 3, 2019
1 parent 692f3c4 commit 767e3f1
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
8 changes: 6 additions & 2 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,11 +1170,15 @@ def nanprod(self, *args, **kwargs):

def mean(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ
"""Returns the average of the array elements along given axis."""
raise NotImplementedError
return mean(self, axis=axis, dtype=dtype, out=out, keepdims=keepdims)

def std(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: disable=arguments-differ
"""Returns the standard deviation of the array elements along given axis."""
return _mx_np_op.std(self, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)
return std(self, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)

def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=None): # pylint: disable=arguments-differ
"""Returns the variance of the array elements, along given axis."""
return var(self, axis=axis, dtype=dtype, out=out, ddof=ddof, keepdims=keepdims)

def cumsum(self, axis=None, dtype=None, out=None):
"""Return the cumulative sum of the elements along the given axis."""
Expand Down
14 changes: 9 additions & 5 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,12 +530,16 @@ def nanprod(self, *args, **kwargs):
raise AttributeError('_Symbol object has no attribute nanprod')

def mean(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ
"""Convenience fluent method for :py:func:`mean`.
"""Returns the average of the array elements along given axis."""
return mean(self, axis=axis, dtype=dtype, out=out, keepdims=keepdims)

The arguments are the same as for :py:func:`mean`, with
this array as data.
"""
return _npi.mean(self, axis=axis, dtype=dtype, keepdims=keepdims, out=out)
def std(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: disable=arguments-differ,too-many-arguments
"""Returns the standard deviation of the array elements along given axis."""
return std(self, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)

def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=None): # pylint: disable=arguments-differ,too-many-arguments
"""Returns the variance of the array elements, along given axis."""
return var(self, axis=axis, dtype=dtype, out=out, ddof=ddof, keepdims=keepdims)

def cumsum(self, axis=None, dtype=None, out=None):
"""Return the cumulative sum of the elements along the given axis."""
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def __init__(self, axis=None, dtype=None, keepdims=False):
self._keepdims = keepdims

def hybrid_forward(self, F, a, *args, **kwargs):
return F.np.mean(a, axis=self._axis, dtype=self._dtype, keepdims=self._keepdims)
return a.mean(axis=self._axis, dtype=self._dtype, keepdims=self._keepdims)

def is_int(dtype):
return 'int' in dtype
Expand Down Expand Up @@ -467,7 +467,7 @@ def __init__(self, name, axis=None, dtype=None, keepdims=False, ddof=0):
self._ddof = ddof

def hybrid_forward(self, F, a, *args, **kwargs):
return getattr(F.np, self._name)(a, axis=self._axis, dtype=self._dtype, keepdims=self._keepdims, ddof=self._ddof)
return getattr(a, self._name)(axis=self._axis, dtype=self._dtype, keepdims=self._keepdims, ddof=self._ddof)

def is_int(dtype):
return 'int' in dtype
Expand Down

0 comments on commit 767e3f1

Please sign in to comment.