diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 0c0a0a1e3c88..91d4ca16df07 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -47,7 +47,7 @@ from ._internal import SymbolBase, _set_symbol_class __all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json", - "pow", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange", + "pow", "power", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange", "histogram", "split_v2"] @@ -2740,6 +2740,8 @@ def pow(base, exp): Both inputs can be Symbol or scalar number. Broadcasting is not supported. Use `broadcast_pow` instead. + `sym.pow` is being deprecated, please use `sym.power` instead. + Parameters --------- base : Symbol or scalar @@ -2780,6 +2782,43 @@ def pow(base, exp): raise TypeError('types (%s, %s) not supported' % (str(type(base)), str(type(exp)))) +def power(base, exp): + """Returns element-wise result of base element raised to powers from exp element. + + Both inputs can be Symbol or scalar number. + Broadcasting is not supported. Use `broadcast_pow` instead. + + Parameters + --------- + base : Symbol or scalar + The base symbol + exp : Symbol or scalar + The exponent symbol + + Returns + ------- + Symbol or scalar + The bases in x raised to the exponents in y. + + Examples + -------- + >>> mx.sym.power(2, 3) + 8 + >>> x = mx.sym.Variable('x') + >>> y = mx.sym.Variable('y') + >>> z = mx.sym.power(x, 2) + >>> z.eval(x=mx.nd.array([1,2]))[0].asnumpy() + array([ 1., 4.], dtype=float32) + >>> z = mx.sym.power(3, y) + >>> z.eval(y=mx.nd.array([2,3]))[0].asnumpy() + array([ 9., 27.], dtype=float32) + >>> z = mx.sym.power(x, y) + >>> z.eval(x=mx.nd.array([3,4]), y=mx.nd.array([2,3]))[0].asnumpy() + array([ 9., 64.], dtype=float32) + """ + return pow(base, exp) + + # pylint: disable=no-member # pylint: disable=redefined-builtin def maximum(left, right): diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 17618e414343..ccb351f434da 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -698,11 +698,11 @@ def test_symbol_pow(): def test_pow_fn(): shape = (3, 4) exp = mx.symbol.Variable("exp") - y = mx.sym.pow(2, exp) x = np.ones(shape)*3 - check_numeric_gradient(y, [x], numeric_eps=1E-3) - check_symbolic_forward(y, [x], [2**x]) - check_symbolic_backward(y, [x], [np.ones(shape)], [np.log(2) * 2**x]) + for y in [mx.sym.pow(2, exp), mx.sym.power(2, exp)]: + check_numeric_gradient(y, [x], numeric_eps=1E-3) + check_symbolic_forward(y, [x], [2**x]) + check_symbolic_backward(y, [x], [np.ones(shape)], [np.log(2) * 2**x]) @with_seed() @@ -6675,7 +6675,12 @@ def test_binary_math_operators(): lambda x, y: np.power(x, y), lambda x, y: np.power(x, y - 1.) * y, lambda x, y: np.power(x, y) * np.log(x), - 0.2, 5.0, -4.0, 4.0] + 0.2, 5.0, -4.0, 4.0], + 'power': [lambda x, y: mx.sym.power(x, y), + lambda x, y: np.power(x, y), + lambda x, y: np.power(x, y - 1.) * y, + lambda x, y: np.power(x, y) * np.log(x), + 0.2, 5.0, -4.0, 4.0] } # Loop over operators for name, op in binary_ops.items():