numpy concatenate (apache#15104)
haojin2 committed Jul 24, 2019
1 parent b82108a commit e792892
Showing 8 changed files with 204 additions and 16 deletions.
27 changes: 26 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
@@ -24,7 +24,7 @@
from ...context import current_context
from . import _internal as _npi

-__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax']
+__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arange', 'argmax']


@set_module('mxnet.ndarray.numpy')
@@ -277,3 +277,28 @@ def argmax(a, axis=None, out=None):
        with the dimension along `axis` removed.
    """
    return _npi.argmax(a, axis=axis, keepdims=False, out=out)


@set_module('mxnet.ndarray.numpy')
def concatenate(seq, axis=0, out=None):
"""Join a sequence of arrays along an existing axis.
Parameters
----------
a1, a2, ... : sequence of array_like
The arrays must have the same shape, except in the dimension
corresponding to `axis` (the first, by default).
axis : int, optional
The axis along which the arrays will be joined. If axis is None,
arrays are flattened before use. Default is 0.
out : ndarray, optional
If provided, the destination to place the result. The shape must be
correct, matching that of what concatenate would have returned if no
out argument were specified.
Returns
-------
res : ndarray
The concatenated array.
"""
return _npi.concatenate(*seq, dim=axis, out=out)
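
A minimal imperative sketch of the new frontend (the imports and the
as_np_ndarray() conversion mirror the test added below; the shapes are
illustrative):

import mxnet as mx
from mxnet import np

# Two illustrative (2, 3) inputs, converted to the new numpy ndarray type.
a = mx.nd.random.uniform(-1.0, 1.0, shape=(2, 3)).as_np_ndarray()
b = mx.nd.random.uniform(-1.0, 1.0, shape=(2, 3)).as_np_ndarray()

# Joining along axis 0 stacks rows: (2, 3) + (2, 3) -> (4, 3).
assert np.concatenate([a, b], axis=0).shape == (4, 3)
# Joining along axis 1 stacks columns: (2, 3) + (2, 3) -> (2, 6).
assert np.concatenate([a, b], axis=1).shape == (2, 6)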
29 changes: 27 additions & 2 deletions python/mxnet/numpy/multiarray.py
@@ -37,8 +37,8 @@
from ..ndarray import numpy as _mx_nd_np
from ..ndarray.numpy import _internal as _npi

-__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange',
-           'argmax']
+__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack',
+           'concatenate', 'arange', 'argmax']


# This function is copied from ndarray.py since pylint
@@ -1486,3 +1486,28 @@ def argmax(a, axis=None, out=None):
        with the dimension along `axis` removed.
    """
    return _mx_nd_np.argmax(a, axis, out)


@set_module('mxnet.numpy')
def concatenate(seq, axis=0, out=None):
"""Join a sequence of arrays along an existing axis.
Parameters
----------
a1, a2, ... : sequence of array_like
The arrays must have the same shape, except in the dimension
corresponding to `axis` (the first, by default).
axis : int, optional
The axis along which the arrays will be joined. If axis is None,
arrays are flattened before use. Default is 0.
out : ndarray, optional
If provided, the destination to place the result. The shape must be
correct, matching that of what concatenate would have returned if no
out argument were specified.
Returns
-------
res : ndarray
The concatenated array.
"""
return _mx_nd_np.concatenate(seq, axis=axis, out=out)
27 changes: 26 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
@@ -29,7 +29,7 @@
from .._internal import _set_np_symbol_class
from . import _internal as _npi

-__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax']
+__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arange', 'argmax']


@set_module('mxnet.symbol.numpy')
@@ -1060,6 +1060,31 @@ def get_list(arrays):
    return _npi.stack(*arrays, axis=axis, out=out)


@set_module('mxnet.symbol.numpy')
def concatenate(seq, axis=0, out=None):
"""Join a sequence of arrays along an existing axis.
Parameters
----------
a1, a2, ... : sequence of array_like
The arrays must have the same shape, except in the dimension
corresponding to `axis` (the first, by default).
axis : int, optional
The axis along which the arrays will be joined. If axis is None,
arrays are flattened before use. Default is 0.
out : ndarray, optional
If provided, the destination to place the result. The shape must be
correct, matching that of what concatenate would have returned if no
out argument were specified.
Returns
-------
res : ndarray
The concatenated array.
"""
return _npi.concatenate(*seq, dim=axis, out=out)
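
Through Gluon, this same call is traced symbolically once a block is
hybridized; a condensed sketch of the pattern the new test below uses
(the class and parameter names are illustrative):

from mxnet.gluon import HybridBlock

class ConcatBlock(HybridBlock):
    """Join two inputs along a fixed axis with F.np.concatenate."""
    def __init__(self, axis=0):
        super(ConcatBlock, self).__init__()
        self._axis = axis

    def hybrid_forward(self, F, a, b):
        # F is the ndarray namespace when run imperatively and this symbol
        # namespace after hybridize(), so one definition covers both paths.
        return F.np.concatenate([a, b], axis=self._axis)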


@set_module('mxnet.symbol.numpy')
def arange(start, stop=None, step=1, dtype=None, ctx=None):
"""Return evenly spaced values within a given interval.
12 changes: 6 additions & 6 deletions src/operator/nn/concat.cc
@@ -32,9 +32,9 @@
namespace mxnet {
namespace op {

-static bool ConcatShape(const nnvm::NodeAttrs& attrs,
-                        mxnet::ShapeVector *in_shape,
-                        mxnet::ShapeVector *out_shape) {
+bool ConcatShape(const nnvm::NodeAttrs& attrs,
+                 mxnet::ShapeVector *in_shape,
+                 mxnet::ShapeVector *out_shape) {
  using namespace mshadow;
  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
@@ -138,9 +138,9 @@ static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
  return shape_is_known(dshape);
}

-static bool ConcatType(const nnvm::NodeAttrs& attrs,
-                       std::vector<int> *in_type,
-                       std::vector<int> *out_type) {
+bool ConcatType(const nnvm::NodeAttrs& attrs,
+                std::vector<int> *in_type,
+                std::vector<int> *out_type) {
  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
  int dtype = -1;

58 changes: 58 additions & 0 deletions src/operator/numpy/np_matrix_op.cc
@@ -24,6 +24,7 @@
*/

#include "./np_matrix_op-inl.h"
#include "../nn/concat-inl.h"

namespace mxnet {
namespace op {
@@ -252,5 +253,62 @@ Examples::
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to stack")
.add_arguments(StackParam::__FIELDS__());

// Shape and type inference are shared with the regular Concat operator;
// both functions are defined (no longer static) in src/operator/nn/concat.cc.
bool ConcatShape(const nnvm::NodeAttrs& attrs,
                 mxnet::ShapeVector *in_shape,
                 mxnet::ShapeVector *out_shape);

bool ConcatType(const nnvm::NodeAttrs& attrs,
                std::vector<int> *in_type,
                std::vector<int> *out_type);

struct NumpyConcatGrad {
  const char *op_name;
  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
                                          const std::vector<nnvm::NodeEntry>& ograds) const {
    CHECK_EQ(ograds.size(), 1);
    std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
    // Forward the single output gradient to the backward op, carrying over
    // the forward node's attributes (num_args, dim).
    return MakeGradNode(op_name, n, heads, n->attrs.dict);
  }
};


NNVM_REGISTER_OP(_npi_concatenate)
.describe(R"code(Join a sequence of arrays along an existing axis.)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
  return params.num_args;
})
.set_num_outputs(1)
.set_attr_parser(ParamParser<ConcatParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
    std::vector<std::string> ret;
    for (int i = 0; i < params.num_args; ++i) {
      ret.push_back(std::string("data") + std::to_string(i));
    }
    return ret;
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"out"};
})
// "key_var_num_args" names the parameter that holds the variable input count.
.set_attr<std::string>("key_var_num_args", "num_args")
.set_attr<nnvm::FInferType>("FInferType", ConcatType)
.set_attr<mxnet::FInferShape>("FInferShape", ConcatShape)
.set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", NumpyConcatGrad{"_backward_np_concat"})
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_np_concat)
.set_num_outputs([](const NodeAttrs& attrs) {
  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
  return params.num_args;
})
.set_attr_parser(ParamParser<ConcatParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);

} // namespace op
} // namespace mxnet
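
The backward op reuses ConcatGradCompute, which in effect splits the
incoming output gradient back into one piece per input along the
concatenation axis. A plain-NumPy sketch of that behavior (illustrative
only, not the actual kernel):

import numpy as np

def concat_grad(ograd, input_shapes, axis=0):
    """Split the output gradient at each input's boundary along `axis`."""
    sizes = [shape[axis] for shape in input_shapes]
    boundaries = np.cumsum(sizes)[:-1]  # split points between consecutive inputs
    return np.split(ograd, boundaries, axis=axis)

# Two (2, 3) inputs joined on axis 0 give a (4, 3) output, so each input
# receives a (2, 3) slice of the output gradient.
grads = concat_grad(np.ones((4, 3)), [(2, 3), (2, 3)], axis=0)
assert [g.shape for g in grads] == [(2, 3), (2, 3)]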
4 changes: 4 additions & 0 deletions src/operator/numpy/np_matrix_op.cu
@@ -23,6 +23,7 @@
* \brief GPU Implementation of numpy matrix operations
*/
#include "./np_matrix_op-inl.h"
#include "../nn/concat-inl.h"

namespace mxnet {
namespace op {
@@ -36,5 +37,8 @@ NNVM_REGISTER_OP(_np_reshape)
NNVM_REGISTER_OP(_npi_stack)
.set_attr<FCompute>("FCompute<gpu>", StackOpForward<gpu>);

NNVM_REGISTER_OP(_npi_concatenate)
.set_attr<FCompute>("FCompute<gpu>", ConcatCompute<gpu>);

} // namespace op
} // namespace mxnet
12 changes: 6 additions & 6 deletions src/operator/quantization/quantized_concat.cc
@@ -28,8 +28,8 @@
namespace mxnet {
namespace op {

-static bool ConcatShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_shape,
-                        mxnet::ShapeVector* out_shape) {
+static bool QuantizedConcatShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_shape,
+                                 mxnet::ShapeVector* out_shape) {
  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args * 3));
  CHECK_EQ(out_shape->size(), 3U);
@@ -74,8 +74,8 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_shape,
  return shape_is_known(dshape);
}

-static bool ConcatType(const nnvm::NodeAttrs& attrs, std::vector<int>* in_type,
-                       std::vector<int>* out_type) {
+static bool QuantizedConcatType(const nnvm::NodeAttrs& attrs, std::vector<int>* in_type,
+                                std::vector<int>* out_type) {
  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
  CHECK_EQ(in_type->size(), static_cast<size_t>(param_.num_args * 3));
  CHECK_EQ(out_type->size(), 3U);
@@ -130,8 +130,8 @@ If any input holds int8, then the output will be int8. Otherwise output will be
// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<nnvm::FInferType>("FInferType", ConcatType)
.set_attr<mxnet::FInferShape>("FInferShape", ConcatShape)
.set_attr<nnvm::FInferType>("FInferType", QuantizedConcatType)
.set_attr<mxnet::FInferShape>("FInferShape", QuantizedConcatShape)
.set_attr<std::string>("key_var_num_args", "num_args")
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());
51 changes: 51 additions & 0 deletions tests/python/unittest/test_numpy_op.py
@@ -633,6 +633,57 @@ def hybrid_forward(self, F, x):
assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-5, rtol=1e-4)


@with_seed()
@npx.use_np_shape
def test_np_concat():
    class TestConcat(HybridBlock):
        def __init__(self, axis=None):
            super(TestConcat, self).__init__()
            self._axis = axis

        def hybrid_forward(self, F, a, *args):
            return F.np.concatenate([a] + list(args), axis=self._axis)

    def get_new_shape(shape, axis):
        shape_lst = list(shape)
        shape_lst[axis] = random.randint(0, 3)
        return tuple(shape_lst)

    for shape in [(0, 0), (2, 3)]:
        for hybridize in [True, False]:
            for axis in range(2):
                # test gluon
                test_concat = TestConcat(axis=axis)
                if hybridize:
                    test_concat.hybridize()

                a = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
                a.attach_grad()
                b = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
                b.attach_grad()
                c = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
                c.attach_grad()
                d = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
                d.attach_grad()
                expected_ret = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
                with mx.autograd.record():
                    y = test_concat(a, b, c, d)
                assert y.shape == expected_ret.shape
                assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)

                y.backward()

                assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)
                assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5)
                assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5)
                assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5)

                # test imperative
                mx_out = np.concatenate([a, b, c, d], axis=axis)
                np_out = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


if __name__ == '__main__':
    import nose
    nose.runmodule()
