Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add nd.power and sym.pow
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Apr 3, 2019
1 parent 5f19362 commit 745f6a7
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 7 deletions.
56 changes: 55 additions & 1 deletion python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
"ones", "add", "arange", "eye", "divide", "equal", "full", "greater", "greater_equal",
"imdecode", "lesser", "lesser_equal", "logical_and", "logical_or", "logical_xor",
"maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal", "onehot_encode",
"power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram",
"pow", "power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram",
"split_v2", "to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack"]

_STORAGE_TYPE_UNDEFINED = -1
Expand Down Expand Up @@ -3021,6 +3021,60 @@ def power(base, exp):
# pylint: enable= no-member, protected-access


def pow(base, exp):
"""Returns result of first array elements raised to powers from second array, element-wise
with broadcasting.
Equivalent to ``base ** exp`` and ``mx.nd.broadcast_power(lhs, rhs)``.
.. note::
If the corresponding dimensions of two arrays have the same size or one of them has size 1,
then the arrays are broadcastable to a common shape.
Parameters
----------
base : scalar or NDArray
The base array
exp : scalar or NDArray
The exponent array. If ``base.shape != exp.shape``, they must be
broadcastable to a common shape.
Returns
--------
NDArray
The bases in x raised to the exponents in y.
Examples
--------
>>> x = mx.nd.ones((2,3))*2
>>> y = mx.nd.arange(1,3).reshape((2,1))
>>> z = mx.nd.arange(1,3).reshape((2,1))
>>> x.asnumpy()
array([[ 2., 2., 2.],
[ 2., 2., 2.]], dtype=float32)
>>> y.asnumpy()
array([[ 1.],
[ 2.]], dtype=float32)
>>> z.asnumpy()
array([[ 1.],
[ 2.]], dtype=float32)
>>> (x**2).asnumpy()
array([[ 4., 4., 4.],
[ 4., 4., 4.]], dtype=float32)
>>> (x**y).asnumpy()
array([[ 2., 2., 2.],
[ 4., 4., 4.]], dtype=float32)
>>> mx.nd.pow(x,y).asnumpy()
array([[ 2., 2., 2.],
[ 4., 4., 4.]], dtype=float32)
>>> (z**y).asnumpy()
array([[ 1.],
[ 4.]], dtype=float32)
"""
power(base, exp)


def maximum(lhs, rhs):
"""Returns element-wise maximum of the input arrays with broadcasting.
Expand Down
39 changes: 38 additions & 1 deletion python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from ._internal import SymbolBase, _set_symbol_class

__all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
"pow", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange",
"pow", "power", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange",
"histogram", "split_v2"]


Expand Down Expand Up @@ -2780,6 +2780,43 @@ def pow(base, exp):
raise TypeError('types (%s, %s) not supported' % (str(type(base)), str(type(exp))))


def power(base, exp):
"""Returns element-wise result of base element raised to powers from exp element.
Both inputs can be Symbol or scalar number.
Broadcasting is not supported. Use `broadcast_pow` instead.
Parameters
---------
base : Symbol or scalar
The base symbol
exp : Symbol or scalar
The exponent symbol
Returns
-------
Symbol or scalar
The bases in x raised to the exponents in y.
Examples
--------
>>> mx.sym.power(2, 3)
8
>>> x = mx.sym.Variable('x')
>>> y = mx.sym.Variable('y')
>>> z = mx.sym.power(x, 2)
>>> z.eval(x=mx.nd.array([1,2]))[0].asnumpy()
array([ 1., 4.], dtype=float32)
>>> z = mx.sym.power(3, y)
>>> z.eval(y=mx.nd.array([2,3]))[0].asnumpy()
array([ 9., 27.], dtype=float32)
>>> z = mx.sym.power(x, y)
>>> z.eval(x=mx.nd.array([3,4]), y=mx.nd.array([2,3]))[0].asnumpy()
array([ 9., 64.], dtype=float32)
"""
return pow(base, exp)


# pylint: disable=no-member
# pylint: disable=redefined-builtin
def maximum(left, right):
Expand Down
15 changes: 10 additions & 5 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,11 +698,11 @@ def test_symbol_pow():
def test_pow_fn():
shape = (3, 4)
exp = mx.symbol.Variable("exp")
y = mx.sym.pow(2, exp)
x = np.ones(shape)*3
check_numeric_gradient(y, [x], numeric_eps=1E-3)
check_symbolic_forward(y, [x], [2**x])
check_symbolic_backward(y, [x], [np.ones(shape)], [np.log(2) * 2**x])
for y in [mx.sym.pow(2, exp), mx.sym.power(2, exp)]:
check_numeric_gradient(y, [x], numeric_eps=1E-3)
check_symbolic_forward(y, [x], [2**x])
check_symbolic_backward(y, [x], [np.ones(shape)], [np.log(2) * 2**x])


@with_seed()
Expand Down Expand Up @@ -6520,7 +6520,12 @@ def test_binary_math_operators():
lambda x, y: np.power(x, y),
lambda x, y: np.power(x, y - 1.) * y,
lambda x, y: np.power(x, y) * np.log(x),
0.2, 5.0, -4.0, 4.0]
0.2, 5.0, -4.0, 4.0],
'power': [lambda x, y: mx.sym.power(x, y),
lambda x, y: np.power(x, y),
lambda x, y: np.power(x, y - 1.) * y,
lambda x, y: np.power(x, y) * np.log(x),
0.2, 5.0, -4.0, 4.0]
}
# Loop over operators
for name, op in binary_ops.items():
Expand Down

0 comments on commit 745f6a7

Please sign in to comment.