Numpy-compatible stack upstream (apache#15842)
haojin2 authored and reminisce committed Aug 18, 2019
1 parent 1a6fe60 commit 8eb0f61
Showing 6 changed files with 174 additions and 3 deletions.
29 changes: 28 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
@@ -27,7 +27,7 @@
from ..ndarray import NDArray

__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot',
-           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate']
+           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack']


@set_module('mxnet.ndarray.numpy')
@@ -705,3 +705,30 @@ def concatenate(seq, axis=0, out=None):
        The concatenated array.
    """
    return _npi.concatenate(*seq, dim=axis, out=out)


@set_module('mxnet.ndarray.numpy')
def stack(arrays, axis=0, out=None):
"""Join a sequence of arrays along a new axis.
The axis parameter specifies the index of the new axis in the dimensions of the result.
For example, if `axis=0` it will be the first dimension and if `axis=-1` it will be the last dimension.
Parameters
----------
arrays : sequence of array_like
Each array must have the same shape.
axis : int, optional
The axis in the result array along which the input arrays are stacked.
out : ndarray, optional
If provided, the destination to place the result. The shape must be correct,
matching that of what stack would have returned if no out argument were specified.
Returns
-------
stacked : ndarray
The stacked array has one more dimension than the input arrays."""
def get_list(arrays):
if not hasattr(arrays, '__getitem__') and hasattr(arrays, '__iter__'):
raise ValueError("expected iterable for arrays but got {}".format(type(arrays)))
return [arr for arr in arrays]

arrays = get_list(arrays)
return _npi.stack(*arrays, axis=axis, out=out)
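For orientation, a minimal usage sketch of the imperative entry point added above (assuming a build containing this commit, with NumPy semantics switched on via `npx.set_np()`):

    from mxnet import np, npx
    npx.set_np()  # enable NumPy-compatible array semantics

    a = np.array([1, 2])
    b = np.array([3, 4])

    np.stack([a, b])           # [[1., 2.], [3., 4.]] -- new axis 0, shape (2, 2)
    np.stack([a, b], axis=1)   # [[1., 3.], [2., 4.]] -- new axis 1
    np.stack([a, b], axis=-1)  # same as axis=1 for 1-D inputs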
23 changes: 22 additions & 1 deletion python/mxnet/numpy/multiarray.py
@@ -45,7 +45,7 @@

__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'add', 'subtract', 'multiply', 'divide',
           'mod', 'power', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split',
-          'concatenate']
+          'concatenate', 'stack']


# This function is copied from ndarray.py since pylint
@@ -1877,3 +1877,24 @@ def concatenate(seq, axis=0, out=None):
        The concatenated array.
    """
    return _mx_nd_np.concatenate(seq, axis=axis, out=out)


@set_module('mxnet.numpy')
def stack(arrays, axis=0, out=None):
"""Join a sequence of arrays along a new axis.
The axis parameter specifies the index of the new axis in the dimensions of the result.
For example, if `axis=0` it will be the first dimension and if `axis=-1` it will be the last dimension.
Parameters
----------
arrays : sequence of array_like
Each array must have the same shape.
axis : int, optional
The axis in the result array along which the input arrays are stacked.
out : ndarray, optional
If provided, the destination to place the result. The shape must be correct,
matching that of what stack would have returned if no out argument were specified.
Returns
-------
stacked : ndarray
The stacked array has one more dimension than the input arrays."""
return _mx_nd_np.stack(arrays, axis=axis, out=out)
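A short sketch of the `out=` path, with `a` and `b` as in the earlier example (the buffer name is hypothetical; per the docstring, the destination shape must match what stack would have returned, and we assume the backend writes into it in place):

    buf = np.zeros((2, 2))
    np.stack([a, b], axis=0, out=buf)  # result is written into buf
    # buf is now [[1., 2.], [3., 4.]]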
29 changes: 28 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
@@ -30,7 +30,7 @@
from . import _internal as _npi

__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot',
-           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate']
+           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack']


def _num_outputs(sym):
@@ -1335,4 +1335,31 @@ def concatenate(seq, axis=0, out=None):
return _npi.concatenate(*seq, dim=axis, out=out)


@set_module('mxnet.symbol.numpy')
def stack(arrays, axis=0, out=None):
"""Join a sequence of arrays along a new axis.
The axis parameter specifies the index of the new axis in the dimensions of the result.
For example, if `axis=0` it will be the first dimension and if `axis=-1` it will be the last dimension.
Parameters
----------
arrays : sequence of array_like
Each array must have the same shape.
axis : int, optional
The axis in the result array along which the input arrays are stacked.
out : ndarray, optional
If provided, the destination to place the result. The shape must be correct,
matching that of what stack would have returned if no out argument were specified.
Returns
-------
stacked : ndarray
The stacked array has one more dimension than the input arrays."""
def get_list(arrays):
if not hasattr(arrays, '__getitem__') and hasattr(arrays, '__iter__'):
raise ValueError("expected iterable for arrays but got {}".format(type(arrays)))
return [arr for arr in arrays]

arrays = get_list(arrays)
return _npi.stack(*arrays, axis=axis, out=out)


_set_np_symbol_class(_Symbol)
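Because the same `stack` signature now exists under both `mxnet.ndarray.numpy` and `mxnet.symbol.numpy`, Gluon code can stay hybridization-agnostic. A minimal sketch mirroring the pattern the tests below use (`a` and `b` as in the earlier imperative example):

    from mxnet.gluon import HybridBlock

    class StackPair(HybridBlock):
        def hybrid_forward(self, F, x, y):
            # F resolves to mxnet.ndarray when running imperatively and to
            # mxnet.symbol once hybridized; F.np.stack picks the matching stack.
            return F.np.stack([x, y], axis=0)

    block = StackPair()
    block.hybridize()  # now traced through mxnet.symbol.numpy.stack
    block(a, b)        # shape (2, 2)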
41 changes: 41 additions & 0 deletions src/operator/numpy/np_matrix_op.cc
@@ -23,6 +23,7 @@
* \brief CPU Implementation of numpy matrix operations
*/

#include <vector>
#include "./np_matrix_op-inl.h"

namespace mxnet {
@@ -304,5 +305,45 @@ NNVM_REGISTER_OP(_backward_np_concat)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);

NNVM_REGISTER_OP(_npi_stack)
.describe(R"code(Join a sequence of arrays along a new axis.

The axis parameter specifies the index of the new axis in the dimensions of the
result. For example, if axis=0 it will be the first dimension and if axis=-1 it
will be the last dimension.

Examples::

  x = [1, 2]
  y = [3, 4]

  stack(x, y) = [[1, 2],
                 [3, 4]]
  stack(x, y, axis=1) = [[1, 3],
                         [2, 4]]
)code")
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
    const StackParam& param = dmlc::get<StackParam>(attrs.parsed);
    return static_cast<uint32_t>(param.num_args);
  })
.set_num_outputs(1)
.set_attr_parser(ParamParser<StackParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    uint32_t num_args = dmlc::get<StackParam>(attrs.parsed).num_args;
    std::vector<std::string> ret;
    for (uint32_t i = 0; i < num_args; ++i) {
      ret.push_back(std::string("arg") + std::to_string(i));
    }
    return ret;
  })
.set_attr<std::string>("key_var_num_args", "num_args")
.set_attr<mxnet::FInferShape>("FInferShape", StackOpShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FCompute>("FCompute<cpu>", StackOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_stack"})
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to stack")
.add_arguments(StackParam::__FIELDS__());

} // namespace op
} // namespace mxnet
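The `key_var_num_args` attribute is what lets the Python front ends above pass a variable number of positional arrays without an explicit count: `num_args` is filled in from the number of positional inputs, and `FListInputNames` then names them `arg0..argN-1`. This is exactly the internal call the Python diffs above make (a sketch; `_npi` is the internal operator module imported in those files):

    # inside stack(), after flattening the input sequence to a list:
    out = _npi.stack(*arrays, axis=axis, out=out)  # num_args inferred as len(arrays)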
4 changes: 4 additions & 0 deletions src/operator/numpy/np_matrix_op.cu
@@ -22,6 +22,7 @@
* \file np_matrix_op.cu
* \brief GPU Implementation of numpy matrix operations
*/

#include "./np_matrix_op-inl.h"

namespace mxnet {
@@ -42,5 +43,8 @@ NNVM_REGISTER_OP(_npi_concatenate)
NNVM_REGISTER_OP(_backward_np_concat)
.set_attr<FCompute>("FCompute<gpu>", ConcatGradCompute<gpu>);

NNVM_REGISTER_OP(_npi_stack)
.set_attr<FCompute>("FCompute<gpu>", StackOpForward<gpu>);

} // namespace op
} // namespace mxnet
51 changes: 51 additions & 0 deletions tests/python/unittest/test_numpy_op.py
@@ -864,6 +864,7 @@ def get_new_shape(shape, axis):
                expected_ret = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
                with mx.autograd.record():
                    y = test_concat(a, b, c, d)

                assert y.shape == expected_ret.shape
                assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)

@@ -880,6 +881,56 @@ def get_new_shape(shape, axis):
                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_np_stack():
    class TestStack(HybridBlock):
        def __init__(self, axis=None):
            super(TestStack, self).__init__()
            self._axis = axis

        def hybrid_forward(self, F, a, *args):
            return F.np.stack([a] + list(args), axis=self._axis)

    # The symbolic front end should return a numpy-style symbol.
    a, b, c, d = mx.sym.Variable("a"), mx.sym.Variable("b"), mx.sym.Variable("c"), mx.sym.Variable("d")
    ret = mx.sym.np.stack([a.as_np_ndarray(), b.as_np_ndarray(), c.as_np_ndarray(), d.as_np_ndarray()])
    assert type(ret) == mx.sym.np._Symbol

    for shape in [(0, 0), (2, 3)]:
        for hybridize in [True, False]:
            for axis in range(2):
                test_stack = TestStack(axis=axis)
                if hybridize:
                    test_stack.hybridize()
                np_a = _np.random.uniform(-1.0, 1.0, shape).astype(_np.float32)
                np_b = _np.random.uniform(-1.0, 1.0, shape).astype(_np.float32)
                np_c = _np.random.uniform(-1.0, 1.0, shape).astype(_np.float32)
                np_d = _np.random.uniform(-1.0, 1.0, shape).astype(_np.float32)

                mx_a = np.array(np_a)
                mx_a.attach_grad()
                mx_b = np.array(np_b)
                mx_b.attach_grad()
                mx_c = np.array(np_c)
                mx_c.attach_grad()
                mx_d = np.array(np_d)
                mx_d.attach_grad()
                expected_ret = _np.stack([np_a, np_b, np_c, np_d], axis=axis)
                with mx.autograd.record():
                    y = test_stack(mx_a, mx_b, mx_c, mx_d)

                y.backward()

                # backward() seeds the head gradient with ones, and stack's
                # backward slices that seed back to each input unchanged.
                assert_almost_equal(mx_a.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5)
                assert_almost_equal(mx_b.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5)
                assert_almost_equal(mx_c.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5)
                assert_almost_equal(mx_d.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5)

                np_out = _np.stack([np_a, np_b, np_c, np_d], axis=axis)
                mx_out = np.stack([mx_a, mx_b, mx_c, mx_d], axis=axis)
                assert same(mx_out.asnumpy(), np_out)
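Why the gradients are all ones: `y.backward()` on a non-scalar output seeds the head gradient with ones, and stack's backward is a pure slicing of that seed along `axis`. A standalone NumPy check of the same identity (hypothetical, independent of MXNet):

    import numpy as onp
    head = onp.ones((4, 2, 3))                 # seed gradient of 4 inputs stacked on axis 0
    input_grads = [head[i] for i in range(4)]  # backward: one slice per input
    assert all((g == onp.ones((2, 3))).all() for g in input_grads)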


if __name__ == '__main__':
    import nose
    nose.runmodule()
