diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index d7f3fd1ace54..a129f738c566 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -27,7 +27,7 @@ from ..ndarray import NDArray __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot', - 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate'] + 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack'] @set_module('mxnet.ndarray.numpy') @@ -705,3 +705,30 @@ def concatenate(seq, axis=0, out=None): The concatenated array. """ return _npi.concatenate(*seq, dim=axis, out=out) + + +@set_module('mxnet.ndarray.numpy') +def stack(arrays, axis=0, out=None): + """Join a sequence of arrays along a new axis. + The axis parameter specifies the index of the new axis in the dimensions of the result. + For example, if `axis=0` it will be the first dimension and if `axis=-1` it will be the last dimension. + Parameters + ---------- + arrays : sequence of array_like + Each array must have the same shape. + axis : int, optional + The axis in the result array along which the input arrays are stacked. + out : ndarray, optional + If provided, the destination to place the result. The shape must be correct, + matching that of what stack would have returned if no out argument were specified. + Returns + ------- + stacked : ndarray + The stacked array has one more dimension than the input arrays.""" + def get_list(arrays): + if not hasattr(arrays, '__getitem__') and hasattr(arrays, '__iter__'): + raise ValueError("expected iterable for arrays but got {}".format(type(arrays))) + return [arr for arr in arrays] + + arrays = get_list(arrays) + return _npi.stack(*arrays, axis=axis, out=out) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 8988b4eb19c9..316d88ae0a81 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -45,7 +45,7 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', - 'concatenate'] + 'concatenate', 'stack'] # This function is copied from ndarray.py since pylint @@ -1877,3 +1877,24 @@ def concatenate(seq, axis=0, out=None): The concatenated array. """ return _mx_nd_np.concatenate(seq, axis=axis, out=out) + + +@set_module('mxnet.numpy') +def stack(arrays, axis=0, out=None): + """Join a sequence of arrays along a new axis. + The axis parameter specifies the index of the new axis in the dimensions of the result. + For example, if `axis=0` it will be the first dimension and if `axis=-1` it will be the last dimension. + Parameters + ---------- + arrays : sequence of array_like + Each array must have the same shape. + axis : int, optional + The axis in the result array along which the input arrays are stacked. + out : ndarray, optional + If provided, the destination to place the result. The shape must be correct, + matching that of what stack would have returned if no out argument were specified. + Returns + ------- + stacked : ndarray + The stacked array has one more dimension than the input arrays.""" + return _mx_nd_np.stack(arrays, axis=axis, out=out) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index a6699d60871a..5537e6305b83 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -30,7 +30,7 @@ from . import _internal as _npi __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot', - 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate'] + 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack'] def _num_outputs(sym): @@ -1335,4 +1335,31 @@ def concatenate(seq, axis=0, out=None): return _npi.concatenate(*seq, dim=axis, out=out) +@set_module('mxnet.symbol.numpy') +def stack(arrays, axis=0, out=None): + """Join a sequence of arrays along a new axis. + The axis parameter specifies the index of the new axis in the dimensions of the result. + For example, if `axis=0` it will be the first dimension and if `axis=-1` it will be the last dimension. + Parameters + ---------- + arrays : sequence of array_like + Each array must have the same shape. + axis : int, optional + The axis in the result array along which the input arrays are stacked. + out : ndarray, optional + If provided, the destination to place the result. The shape must be correct, + matching that of what stack would have returned if no out argument were specified. + Returns + ------- + stacked : ndarray + The stacked array has one more dimension than the input arrays.""" + def get_list(arrays): + if not hasattr(arrays, '__getitem__') and hasattr(arrays, '__iter__'): + raise ValueError("expected iterable for arrays but got {}".format(type(arrays))) + return [arr for arr in arrays] + + arrays = get_list(arrays) + return _npi.stack(*arrays, axis=axis, out=out) + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 73340981037d..5ad6c8908017 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -23,6 +23,7 @@ * \brief CPU Implementation of numpy matrix operations */ +#include #include "./np_matrix_op-inl.h" namespace mxnet { @@ -304,5 +305,45 @@ NNVM_REGISTER_OP(_backward_np_concat) .set_attr("TIsBackward", true) .set_attr("FCompute", ConcatGradCompute); +NNVM_REGISTER_OP(_npi_stack) +.describe(R"code(Join a sequence of arrays along a new axis. + +The axis parameter specifies the index of the new axis in the dimensions of the +result. For example, if axis=0 it will be the first dimension and if axis=-1 it +will be the last dimension. + +Examples:: + + x = [1, 2] + y = [3, 4] + + stack(x, y) = [[1, 2], + [3, 4]] + stack(x, y, axis=1) = [[1, 3], + [2, 4]] +)code") +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const StackParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_args); + }) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + uint32_t num_args = dmlc::get(attrs.parsed).num_args; + std::vector ret; + for (uint32_t i = 0; i < num_args; ++i) { + ret.push_back(std::string("arg") + std::to_string(i)); + } + return ret; + }) +.set_attr("key_var_num_args", "num_args") +.set_attr("FInferShape", StackOpShape) +.set_attr("FInferType", ElemwiseType<-1, 1>) +.set_attr("FCompute", StackOpForward) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_stack"}) +.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to stack") +.add_arguments(StackParam::__FIELDS__()); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index f192560f4ac9..4ba527deca09 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -22,6 +22,7 @@ * \file np_matrix_op.cu * \brief GPU Implementation of numpy matrix operations */ + #include "./np_matrix_op-inl.h" namespace mxnet { @@ -42,5 +43,8 @@ NNVM_REGISTER_OP(_npi_concatenate) NNVM_REGISTER_OP(_backward_np_concat) .set_attr("FCompute", ConcatGradCompute); +NNVM_REGISTER_OP(_npi_stack) +.set_attr("FCompute", StackOpForward); + } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 2291bcdb6d3d..6a1a6a6dfd5d 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -864,6 +864,7 @@ def get_new_shape(shape, axis): expected_ret = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis) with mx.autograd.record(): y = test_concat(a, b, c, d) + assert y.shape == expected_ret.shape assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5) @@ -880,6 +881,56 @@ def get_new_shape(shape, axis): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) +@with_seed() +@use_np +def test_np_stack(): + class TestStack(HybridBlock): + def __init__(self, axis=None): + super(TestStack, self).__init__() + self._axis = axis + + def hybrid_forward(self, F, a, *args): + return F.np.stack([a] + list(args), axis=self._axis) + + a, b, c, d = mx.sym.Variable("a"), mx.sym.Variable("b"), mx.sym.Variable("c"), mx.sym.Variable("d") + ret = mx.sym.np.stack([a.as_np_ndarray(), b.as_np_ndarray(), c.as_np_ndarray(), d.as_np_ndarray()]) + assert type(ret) == mx.sym.np._Symbol + + for shape in [(0, 0), (2, 3)]: + for hybridize in [True, False]: + for axis in range(2): + test_stack = TestStack(axis=axis) + if hybridize: + test_stack.hybridize() + np_a = _np.random.uniform(-1.0, 1.0, shape).astype(_np.float32) + np_b = _np.random.uniform(-1.0, 1.0, shape).astype(_np.float32) + np_c = _np.random.uniform(-1.0, 1.0, shape).astype(_np.float32) + np_d = _np.random.uniform(-1.0, 1.0, shape).astype(_np.float32) + + mx_a = np.array(np_a) + mx_a.attach_grad() + mx_b = np.array(np_b) + mx_b.attach_grad() + mx_c = np.array(np_c) + mx_c.attach_grad() + mx_d = np.array(np_d) + mx_d.attach_grad() + expected_ret = _np.stack([np_a, np_b, np_c, np_d], axis=axis) + with mx.autograd.record(): + y = test_stack(mx_a, mx_b, mx_c, mx_d) + + y.backward() + + assert_almost_equal(mx_a.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5) + assert_almost_equal(mx_b.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5) + assert_almost_equal(mx_c.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5) + assert_almost_equal(mx_d.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5) + + np_out = _np.stack([np_a, np_b, np_c, np_d], axis=axis) + mx_out = np.stack([mx_a, mx_b, mx_c, mx_d], axis=axis) + assert same(mx_out.asnumpy(), np_out) + + if __name__ == '__main__': import nose nose.runmodule()