Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add nd.power and sym.pow (#14606)
Browse files Browse the repository at this point in the history
* add nd.power and sym.pow

* deprecate sym.pow, get rid of nd.pow
  • Loading branch information
haojin2 authored Apr 11, 2019
1 parent 26b14bc commit 596ef3a
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 6 deletions.
41 changes: 40 additions & 1 deletion python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from ._internal import SymbolBase, _set_symbol_class

__all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
"pow", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange",
"pow", "power", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange",
"histogram", "split_v2"]


Expand Down Expand Up @@ -2740,6 +2740,8 @@ def pow(base, exp):
Both inputs can be Symbol or scalar number.
Broadcasting is not supported. Use `broadcast_pow` instead.
`sym.pow` is being deprecated, please use `sym.power` instead.
Parameters
---------
base : Symbol or scalar
Expand Down Expand Up @@ -2780,6 +2782,43 @@ def pow(base, exp):
raise TypeError('types (%s, %s) not supported' % (str(type(base)), str(type(exp))))


def power(base, exp):
"""Returns element-wise result of base element raised to powers from exp element.
Both inputs can be Symbol or scalar number.
Broadcasting is not supported. Use `broadcast_pow` instead.
Parameters
---------
base : Symbol or scalar
The base symbol
exp : Symbol or scalar
The exponent symbol
Returns
-------
Symbol or scalar
The bases in x raised to the exponents in y.
Examples
--------
>>> mx.sym.power(2, 3)
8
>>> x = mx.sym.Variable('x')
>>> y = mx.sym.Variable('y')
>>> z = mx.sym.power(x, 2)
>>> z.eval(x=mx.nd.array([1,2]))[0].asnumpy()
array([ 1., 4.], dtype=float32)
>>> z = mx.sym.power(3, y)
>>> z.eval(y=mx.nd.array([2,3]))[0].asnumpy()
array([ 9., 27.], dtype=float32)
>>> z = mx.sym.power(x, y)
>>> z.eval(x=mx.nd.array([3,4]), y=mx.nd.array([2,3]))[0].asnumpy()
array([ 9., 64.], dtype=float32)
"""
return pow(base, exp)


# pylint: disable=no-member
# pylint: disable=redefined-builtin
def maximum(left, right):
Expand Down
15 changes: 10 additions & 5 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,11 +698,11 @@ def test_symbol_pow():
def test_pow_fn():
shape = (3, 4)
exp = mx.symbol.Variable("exp")
y = mx.sym.pow(2, exp)
x = np.ones(shape)*3
check_numeric_gradient(y, [x], numeric_eps=1E-3)
check_symbolic_forward(y, [x], [2**x])
check_symbolic_backward(y, [x], [np.ones(shape)], [np.log(2) * 2**x])
for y in [mx.sym.pow(2, exp), mx.sym.power(2, exp)]:
check_numeric_gradient(y, [x], numeric_eps=1E-3)
check_symbolic_forward(y, [x], [2**x])
check_symbolic_backward(y, [x], [np.ones(shape)], [np.log(2) * 2**x])


@with_seed()
Expand Down Expand Up @@ -6675,7 +6675,12 @@ def test_binary_math_operators():
lambda x, y: np.power(x, y),
lambda x, y: np.power(x, y - 1.) * y,
lambda x, y: np.power(x, y) * np.log(x),
0.2, 5.0, -4.0, 4.0]
0.2, 5.0, -4.0, 4.0],
'power': [lambda x, y: mx.sym.power(x, y),
lambda x, y: np.power(x, y),
lambda x, y: np.power(x, y - 1.) * y,
lambda x, y: np.power(x, y) * np.log(x),
0.2, 5.0, -4.0, 4.0]
}
# Loop over operators
for name, op in binary_ops.items():
Expand Down

0 comments on commit 596ef3a

Please sign in to comment.