From 745f6a744654fc82eaf526f9e9321823f72aa1ef Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Wed, 3 Apr 2019 06:01:01 +0000 Subject: [PATCH] add nd.power and sym.pow --- python/mxnet/ndarray/ndarray.py | 56 +++++++++++++++++++++++++- python/mxnet/symbol/symbol.py | 39 +++++++++++++++++- tests/python/unittest/test_operator.py | 15 ++++--- 3 files changed, 103 insertions(+), 7 deletions(-) diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index acb7b283aa76..cc17c1daf269 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -46,7 +46,7 @@ "ones", "add", "arange", "eye", "divide", "equal", "full", "greater", "greater_equal", "imdecode", "lesser", "lesser_equal", "logical_and", "logical_or", "logical_xor", "maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal", "onehot_encode", - "power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram", + "pow", "power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram", "split_v2", "to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack"] _STORAGE_TYPE_UNDEFINED = -1 @@ -3021,6 +3021,60 @@ def power(base, exp): # pylint: enable= no-member, protected-access +def pow(base, exp): + """Returns result of first array elements raised to powers from second array, element-wise + with broadcasting. + + Equivalent to ``base ** exp`` and ``mx.nd.broadcast_power(lhs, rhs)``. + + .. note:: + + If the corresponding dimensions of two arrays have the same size or one of them has size 1, + then the arrays are broadcastable to a common shape. + + Parameters + ---------- + base : scalar or NDArray + The base array + exp : scalar or NDArray + The exponent array. If ``base.shape != exp.shape``, they must be + broadcastable to a common shape. + + Returns + -------- + NDArray + The bases in x raised to the exponents in y. + + Examples + -------- + >>> x = mx.nd.ones((2,3))*2 + >>> y = mx.nd.arange(1,3).reshape((2,1)) + >>> z = mx.nd.arange(1,3).reshape((2,1)) + >>> x.asnumpy() + array([[ 2., 2., 2.], + [ 2., 2., 2.]], dtype=float32) + >>> y.asnumpy() + array([[ 1.], + [ 2.]], dtype=float32) + >>> z.asnumpy() + array([[ 1.], + [ 2.]], dtype=float32) + >>> (x**2).asnumpy() + array([[ 4., 4., 4.], + [ 4., 4., 4.]], dtype=float32) + >>> (x**y).asnumpy() + array([[ 2., 2., 2.], + [ 4., 4., 4.]], dtype=float32) + >>> mx.nd.pow(x,y).asnumpy() + array([[ 2., 2., 2.], + [ 4., 4., 4.]], dtype=float32) + >>> (z**y).asnumpy() + array([[ 1.], + [ 4.]], dtype=float32) + """ + power(base, exp) + + def maximum(lhs, rhs): """Returns element-wise maximum of the input arrays with broadcasting. diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 0c0a0a1e3c88..457a3ef23b7b 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -47,7 +47,7 @@ from ._internal import SymbolBase, _set_symbol_class __all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json", - "pow", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange", + "pow", "power", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange", "histogram", "split_v2"] @@ -2780,6 +2780,43 @@ def pow(base, exp): raise TypeError('types (%s, %s) not supported' % (str(type(base)), str(type(exp)))) +def power(base, exp): + """Returns element-wise result of base element raised to powers from exp element. + + Both inputs can be Symbol or scalar number. + Broadcasting is not supported. Use `broadcast_pow` instead. + + Parameters + --------- + base : Symbol or scalar + The base symbol + exp : Symbol or scalar + The exponent symbol + + Returns + ------- + Symbol or scalar + The bases in x raised to the exponents in y. + + Examples + -------- + >>> mx.sym.power(2, 3) + 8 + >>> x = mx.sym.Variable('x') + >>> y = mx.sym.Variable('y') + >>> z = mx.sym.power(x, 2) + >>> z.eval(x=mx.nd.array([1,2]))[0].asnumpy() + array([ 1., 4.], dtype=float32) + >>> z = mx.sym.power(3, y) + >>> z.eval(y=mx.nd.array([2,3]))[0].asnumpy() + array([ 9., 27.], dtype=float32) + >>> z = mx.sym.power(x, y) + >>> z.eval(x=mx.nd.array([3,4]), y=mx.nd.array([2,3]))[0].asnumpy() + array([ 9., 64.], dtype=float32) + """ + return pow(base, exp) + + # pylint: disable=no-member # pylint: disable=redefined-builtin def maximum(left, right): diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index c9498ecb0bd2..234eb769cb50 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -698,11 +698,11 @@ def test_symbol_pow(): def test_pow_fn(): shape = (3, 4) exp = mx.symbol.Variable("exp") - y = mx.sym.pow(2, exp) x = np.ones(shape)*3 - check_numeric_gradient(y, [x], numeric_eps=1E-3) - check_symbolic_forward(y, [x], [2**x]) - check_symbolic_backward(y, [x], [np.ones(shape)], [np.log(2) * 2**x]) + for y in [mx.sym.pow(2, exp), mx.sym.power(2, exp)]: + check_numeric_gradient(y, [x], numeric_eps=1E-3) + check_symbolic_forward(y, [x], [2**x]) + check_symbolic_backward(y, [x], [np.ones(shape)], [np.log(2) * 2**x]) @with_seed() @@ -6520,7 +6520,12 @@ def test_binary_math_operators(): lambda x, y: np.power(x, y), lambda x, y: np.power(x, y - 1.) * y, lambda x, y: np.power(x, y) * np.log(x), - 0.2, 5.0, -4.0, 4.0] + 0.2, 5.0, -4.0, 4.0], + 'power': [lambda x, y: mx.sym.power(x, y), + lambda x, y: np.power(x, y), + lambda x, y: np.power(x, y - 1.) * y, + lambda x, y: np.power(x, y) * np.log(x), + 0.2, 5.0, -4.0, 4.0] } # Loop over operators for name, op in binary_ops.items():