This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Numpy-compatible concatenate upstream #15843

Closed
wants to merge 2 commits
2 changes: 1 addition & 1 deletion ci/jenkins/Jenkinsfile_unix_cpu
@@ -21,7 +21,7 @@
// See documents at https://jenkins.io/doc/book/pipeline/jenkinsfile/

// timeout in minutes
-max_time = 180
+max_time = 240

node('utility') {
// Loading the utilities requires a node context unfortunately
25 changes: 24 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
@@ -27,7 +27,7 @@
from ..ndarray import NDArray

__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot',
-           'linspace', 'expand_dims', 'tile', 'arange', 'split']
+           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate']


@set_module('mxnet.ndarray.numpy')
@@ -682,3 +682,26 @@ def split(ary, indices_or_sections, axis=0):
    if not isinstance(ret, list):
        return [ret]
    return ret


@set_module('mxnet.ndarray.numpy')
def concatenate(seq, axis=0, out=None):
    """Join a sequence of arrays along an existing axis.

    Parameters
    ----------
    seq : sequence of array_like
        The arrays must have the same shape, except in the dimension
        corresponding to `axis` (the first, by default).
    axis : int, optional
        The axis along which the arrays will be joined. If axis is None,
        arrays are flattened before use. Default is 0.
    out : ndarray, optional
        If provided, the destination to place the result. The shape must be
        correct, matching that of what concatenate would have returned if no
        out argument were specified.

    Returns
    -------
    res : ndarray
        The concatenated array.
    """
    return _npi.concatenate(*seq, dim=axis, out=out)
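
For orientation, a minimal imperative usage sketch of the new operator (a hypothetical session, not part of the diff; assumes a build with the numpy-compatible namespace enabled):

    from mxnet import numpy as np  # MXNet's numpy-compatible front end

    a = np.ones((2, 3))
    b = np.zeros((3, 3))
    # Shapes must agree except along the concatenation axis.
    c = np.concatenate([a, b], axis=0)  # result shape: (5, 3)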
26 changes: 25 additions & 1 deletion python/mxnet/numpy/multiarray.py
@@ -44,7 +44,8 @@
from ..ndarray.numpy import _internal as _npi

__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'add', 'subtract', 'multiply', 'divide',
-           'mod', 'power', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split']
+           'mod', 'power', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split',
+           'concatenate']


# This function is copied from ndarray.py since pylint
@@ -1853,3 +1854,26 @@ def split(ary, indices_or_sections, axis=0):
        If `indices_or_sections` is given as an integer, but
        a split does not result in equal division."""
    return _mx_nd_np.split(ary, indices_or_sections, axis=axis)


@set_module('mxnet.numpy')
def concatenate(seq, axis=0, out=None):
    """Join a sequence of arrays along an existing axis.

    Parameters
    ----------
    seq : sequence of array_like
        The arrays must have the same shape, except in the dimension
        corresponding to `axis` (the first, by default).
    axis : int, optional
        The axis along which the arrays will be joined. If axis is None,
        arrays are flattened before use. Default is 0.
    out : ndarray, optional
        If provided, the destination to place the result. The shape must be
        correct, matching that of what concatenate would have returned if no
        out argument were specified.

    Returns
    -------
    res : ndarray
        The concatenated array.
    """
    return _mx_nd_np.concatenate(seq, axis=axis, out=out)
25 changes: 24 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
@@ -30,7 +30,7 @@
from . import _internal as _npi

__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot',
-           'linspace', 'expand_dims', 'tile', 'arange', 'split']
+           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate']


def _num_outputs(sym):
@@ -1312,4 +1312,27 @@ def split(ary, indices_or_sections, axis=0):
    return ret


@set_module('mxnet.symbol.numpy')
def concatenate(seq, axis=0, out=None):
    """Join a sequence of arrays along an existing axis.

    Parameters
    ----------
    seq : sequence of array_like
        The arrays must have the same shape, except in the dimension
        corresponding to `axis` (the first, by default).
    axis : int, optional
        The axis along which the arrays will be joined. If axis is None,
        arrays are flattened before use. Default is 0.
    out : ndarray, optional
        If provided, the destination to place the result. The shape must be
        correct, matching that of what concatenate would have returned if no
        out argument were specified.

    Returns
    -------
    res : ndarray
        The concatenated array.
    """
    return _npi.concatenate(*seq, dim=axis, out=out)


_set_np_symbol_class(_Symbol)
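
The symbol-level `concatenate` above is what hybridized Gluon code reaches through `F`; a hedged sketch of that path (the block name is illustrative, mirroring the unit test below):

    from mxnet.gluon import HybridBlock

    class ConcatPair(HybridBlock):
        def hybrid_forward(self, F, x, y):
            # F is mxnet.symbol.numpy after hybridize(), mxnet.ndarray.numpy before
            return F.np.concatenate([x, y], axis=1)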
1 change: 1 addition & 0 deletions src/operator/numpy/np_matrix_op-inl.h
@@ -27,6 +27,7 @@

#include <vector>
#include "../tensor/matrix_op-inl.h"
#include "../nn/concat-inl.h"

namespace mxnet {
namespace op {
58 changes: 57 additions & 1 deletion src/operator/numpy/np_matrix_op.cc
@@ -24,7 +24,6 @@
*/

#include "./np_matrix_op-inl.h"
#include "../nn/concat-inl.h"

namespace mxnet {
namespace op {
@@ -248,5 +247,62 @@ NNVM_REGISTER_OP(_np_squeeze)
.add_argument("a", "NDArray-or-Symbol[]", "data to squeeze")
.add_arguments(SqueezeParam::__FIELDS__());

bool ConcatShape(const nnvm::NodeAttrs& attrs,
                 mxnet::ShapeVector *in_shape,
                 mxnet::ShapeVector *out_shape);

bool ConcatType(const nnvm::NodeAttrs& attrs,
                std::vector<int> *in_type,
                std::vector<int> *out_type);

struct NumpyConcatGrad {
  const char *op_name;
  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
                                          const std::vector<nnvm::NodeEntry>& ograds) const {
    CHECK_EQ(ograds.size(), 1);
    std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
    return MakeGradNode(op_name, n, heads, n->attrs.dict);
  }
};


NNVM_REGISTER_OP(_npi_concatenate)
.describe(R"code(Join a sequence of arrays along an existing axis.)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
  return params.num_args;
})
.set_num_outputs(1)
.set_attr_parser(ParamParser<ConcatParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
    std::vector<std::string> ret;
    for (int i = 0; i < params.num_args; ++i) {
      ret.push_back(std::string("data") + std::to_string(i));
    }
    return ret;
  })
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"out"};
  })
.set_attr<std::string>("key_var_num_args", "num_args")
.set_attr<nnvm::FInferType>("FInferType", ConcatType)
.set_attr<mxnet::FInferShape>("FInferShape", ConcatShape)
.set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", NumpyConcatGrad{"_backward_np_concat"})
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_np_concat)
.set_num_outputs([](const NodeAttrs& attrs) {
  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
  return params.num_args;
})
.set_attr_parser(ParamParser<ConcatParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);

} // namespace op
} // namespace mxnet
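
For intuition, the backward registered here slices the incoming output gradient into one piece per input along the concatenation axis (the work `ConcatGradCompute` performs), which is why the unit test below expects all-ones input gradients. A plain-NumPy sketch of the idea, with illustrative shapes:

    import numpy as np

    ograd = np.ones((5, 3))  # gradient w.r.t. the concatenated output
    sizes = [2, 3]           # axis-0 extents of the two inputs
    igrads = np.split(ograd, np.cumsum(sizes)[:-1], axis=0)
    # igrads[0].shape == (2, 3); igrads[1].shape == (3, 3)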
7 changes: 6 additions & 1 deletion src/operator/numpy/np_matrix_op.cu
@@ -23,7 +23,6 @@
* \brief GPU Implementation of numpy matrix operations
*/
#include "./np_matrix_op-inl.h"
#include "../nn/concat-inl.h"

namespace mxnet {
namespace op {
@@ -37,5 +36,11 @@ NNVM_REGISTER_OP(_np_reshape)
NNVM_REGISTER_OP(_np_squeeze)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);

NNVM_REGISTER_OP(_npi_concatenate)
.set_attr<FCompute>("FCompute<gpu>", ConcatCompute<gpu>);

NNVM_REGISTER_OP(_backward_np_concat)
.set_attr<FCompute>("FCompute<gpu>", ConcatGradCompute<gpu>);

} // namespace op
} // namespace mxnet
51 changes: 51 additions & 0 deletions tests/python/unittest/test_numpy_op.py
@@ -829,6 +829,57 @@ def get_indices(axis_size):
        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_np_concat():
    class TestConcat(HybridBlock):
        def __init__(self, axis=None):
            super(TestConcat, self).__init__()
            self._axis = axis

        def hybrid_forward(self, F, a, *args):
            return F.np.concatenate([a] + list(args), axis=self._axis)

    def get_new_shape(shape, axis):
        shape_lst = list(shape)
        shape_lst[axis] = random.randint(0, 3)
        return tuple(shape_lst)

    for shape in [(0, 0), (2, 3)]:
        for hybridize in [True, False]:
            for axis in range(2):
                # test gluon
                test_concat = TestConcat(axis=axis)
                if hybridize:
                    test_concat.hybridize()

                a = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
                a.attach_grad()
                b = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
                b.attach_grad()
                c = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
                c.attach_grad()
                d = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
                d.attach_grad()
                expected_ret = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
                with mx.autograd.record():
                    y = test_concat(a, b, c, d)
                assert y.shape == expected_ret.shape
                assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)

                y.backward()

                assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)
                assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5)
                assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5)
                assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5)

                # test imperative
                mx_out = np.concatenate([a, b, c, d], axis=axis)
                np_out = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


if __name__ == '__main__':
    import nose
    nose.runmodule()