From 9c5d3645a92ee845aa8d217b65b3262f627d2da7 Mon Sep 17 00:00:00 2001 From: Mike Mao Date: Fri, 21 Jun 2019 09:09:36 +0000 Subject: [PATCH 1/2] Add dstack that pass CPU test Rgister dstack on GPU Minor comment fix Minor syntax fix Syntax fix according to comments header fix --- python/mxnet/ndarray/numpy/_op.py | 18 ++++- python/mxnet/numpy/multiarray.py | 18 ++++- python/mxnet/symbol/numpy/_symbol.py | 16 +++- src/operator/nn/concat-inl.h | 62 +++++++++++++++ src/operator/numpy/np_matrix_op.cc | 100 ++++++++++++++++++++++++- src/operator/numpy/np_matrix_op.cu | 6 ++ tests/python/unittest/test_numpy_op.py | 61 +++++++++++++++ 7 files changed, 277 insertions(+), 4 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index d7f3fd1ace54..7686dbb113d7 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -27,7 +27,7 @@ from ..ndarray import NDArray __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot', - 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate'] + 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'dstack'] @set_module('mxnet.ndarray.numpy') @@ -200,6 +200,22 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou #pylint: enable= too-many-arguments, no-member, protected-access +@set_module('mxnet.ndarray.numpy') +def dstack(arrays): + """Stack tensors in sequence depth wise. + This is equivalent to concatenation along the third axis, except for zero + dimensional, 1-D or 2D tensors, in which case the first dimension is used. + Parameters + ---------- + arrays : sequence of array_like + Each array must have the same shape. + Returns + ------- + depth-wisely concatenated ndarray + """ + return _npi.dstack(*arrays) + + @set_module('mxnet.ndarray.numpy') def add(x1, x2, out=None): """Add arguments element-wise. diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 8988b4eb19c9..ae079bc1a060 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -45,7 +45,7 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', - 'concatenate'] + 'concatenate', 'dstack'] # This function is copied from ndarray.py since pylint @@ -1877,3 +1877,19 @@ def concatenate(seq, axis=0, out=None): The concatenated array. """ return _mx_nd_np.concatenate(seq, axis=axis, out=out) + + +@set_module('mxnet.numpy') +def dstack(arrays): + """Stack tensors in sequence depth wise. + This is equivalent to concatenation along the third axis, except for zero + dimensional, 1-D or 2D tensors, in which case the first dimension is used. + Parameters + ---------- + arrays : sequence of array_like + Each array must have the same shape. + Returns + ------- + depth-wisely concatenated ndarray + """ + return _npi.dstack(*arrays) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index a6699d60871a..914b2bd323f8 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -30,7 +30,7 @@ from . import _internal as _npi __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot', - 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate'] + 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'dstack'] def _num_outputs(sym): @@ -1135,6 +1135,20 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis else: return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype) +@set_module('mxnet.symbol.numpy') +def dstack(arrays): + """Stack tensors in sequence depth wise. + This is equivalent to concatenation along the third axis, except for zero + dimensional, 1-D or 2D tensors, in which case the first dimension is used. + Parameters + ---------- + arrays : sequence of array_like + Each array must have the same shape. + Returns + ------- + depth-wisely concatenated ndarray + """ + return _npi.dstack(*arrays) @set_module('mxnet.symbol.numpy') def expand_dims(a, axis): diff --git a/src/operator/nn/concat-inl.h b/src/operator/nn/concat-inl.h index 7a58ae6f0ccc..1fb20ac65473 100644 --- a/src/operator/nn/concat-inl.h +++ b/src/operator/nn/concat-inl.h @@ -141,6 +141,37 @@ void ConcatCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, }); } +template +void DStackCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + ConcatParam param = nnvm::get(attrs.parsed); + param.dim = 2; + std::vector modified_inputs(inputs.size()); + for (int i = 0; i < param.num_args; ++i) { + if (inputs[i].shape_.ndim() == 0) { + modified_inputs[i] = inputs[i].reshape(TShape(3, 1)); + } else if (inputs[i].shape_.ndim() == 1) { + TShape t = TShape(3, 1); + t[1] = inputs[i].shape_[0]; + modified_inputs[i] = inputs[i].reshape(t); + } else if (inputs[i].shape_.ndim() == 2) { + TShape t = TShape(3, 1); + t[0] = inputs[i].shape_[0]; + t[1] = inputs[i].shape_[1]; + modified_inputs[i] = inputs[i].reshape(t); + } else { + modified_inputs[i] = inputs[i]; + } + } + MSHADOW_TYPE_SWITCH(inputs[concat_enum::kData0].type_flag_, DType, { + ConcatOp op; + op.Init(param); + op.Forward(ctx, modified_inputs, req, outputs); + }); +} + template void ConcatGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -154,6 +185,37 @@ void ConcatGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, }); } +template +void DStackGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + ConcatParam param = nnvm::get(attrs.parsed); + param.dim = 2; + std::vector modified_outputs(outputs.size()); + for (int i = 0; i < param.num_args; ++i) { + if (outputs[i].shape_.ndim() == 0) { + modified_outputs[i] = outputs[i].reshape(TShape(3, 1)); + } else if (outputs[i].shape_.ndim() == 1) { + TShape t = TShape(3, 1); + t[1] = outputs[i].shape_[0]; + modified_outputs[i] = outputs[i].reshape(t); + } else if (outputs[i].shape_.ndim() == 2) { + TShape t = TShape(3, 1); + t[0] = outputs[i].shape_[0]; + t[1] = outputs[i].shape_[1]; + modified_outputs[i] = outputs[i].reshape(t); + } else { + modified_outputs[i] = outputs[i]; + } + } + MSHADOW_TYPE_SWITCH(inputs[concat_enum::kOut].type_flag_, DType, { + ConcatOp op; + op.Init(param); + op.Backward(ctx, inputs[concat_enum::kOut], req, modified_outputs); + }); +} + /*! * \brief concat CSRNDArray on the first dimension. */ diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 73340981037d..764a379a527c 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -251,6 +251,67 @@ bool ConcatShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_shape, mxnet::ShapeVector *out_shape); +bool DStackShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_shape, + mxnet::ShapeVector *out_shape) { + using namespace mshadow; + ConcatParam param_ = nnvm::get(attrs.parsed); + CHECK_EQ(in_shape->size(), static_cast(param_.num_args)); + mxnet::TShape dshape; + dim_t size = 0; + bool has_unknown_dim_size = false; + int axis = 2; + param_.dim = axis; + for (int i = 0; i < param_.num_args; ++i) { + if ((*in_shape)[i].ndim() == 0) { + (*in_shape)[i] = mxnet::TShape(3, 1); + } else if ((*in_shape)[i].ndim() == 1) { + mxnet::TShape t = mxnet::TShape(3, 1); + t[1] = (*in_shape)[i][0]; + (*in_shape)[i] = t; + } else if ((*in_shape)[i].ndim() == 2) { + mxnet::TShape t = mxnet::TShape(3, 1); + t[0] = (*in_shape)[i][0]; + t[1] = (*in_shape)[i][1]; + (*in_shape)[i] = t; + } + mxnet::TShape &tmp = (*in_shape)[i]; + if (tmp.ndim() > 0) { + CheckAxis(axis, tmp.ndim()); + if (!mxnet::dim_size_is_known(tmp, axis)) { + has_unknown_dim_size = true; + } else { + size += tmp[axis]; + } + tmp[axis] = -1; + shape_assign(&dshape, tmp); + } + } + + mxnet::TShape tmp = (*out_shape)[0]; + if (tmp.ndim() > 0) { + axis = CheckAxis(param_.dim, tmp.ndim()); + tmp[axis] = -1; + shape_assign(&dshape, tmp); + } + + if (dshape.ndim() == -1) return false; + CHECK_NE(dshape.ndim(), 0) << "zero-dimensional arrays cannot be concatenated"; + + for (int i = 0; i < param_.num_args; ++i) { + CHECK(shape_assign(&(*in_shape)[i], dshape)) + << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i]; + } + + if (!has_unknown_dim_size) { + dshape[axis] = size; + } + CHECK(shape_assign(&(*out_shape)[0], dshape)) + << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0]; + + return shape_is_known(dshape); +} + bool ConcatType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type); @@ -265,7 +326,6 @@ struct NumpyConcatGrad { } }; - NNVM_REGISTER_OP(_npi_concatenate) .describe(R"code(Join a sequence of arrays along an existing axis.)code" ADD_FILELINE) .set_num_inputs([](const NodeAttrs& attrs) { @@ -295,6 +355,35 @@ NNVM_REGISTER_OP(_npi_concatenate) .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate") .add_arguments(ConcatParam::__FIELDS__()); +NNVM_REGISTER_OP(_npi_dstack) +.describe(R"code(Stack tensors in sequence depthwise (in third dimension))code" ADD_FILELINE) +.set_num_inputs([](const NodeAttrs& attrs) { + const ConcatParam& params = nnvm::get(attrs.parsed); + return params.num_args; +}) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const ConcatParam& params = nnvm::get(attrs.parsed); + std::vector ret; + for (int i = 0; i < params.num_args; ++i) { + ret.push_back(std::string("data") + std::to_string(i)); + } + return ret; +}) +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"out"}; +}) +.set_attr("key_var_num_args", "num_args") +.set_attr("FInferType", ConcatType) +.set_attr("FInferShape", DStackShape) +.set_attr("FCompute", DStackCompute) +.set_attr("FGradient", NumpyConcatGrad{"_backward_np_dstack"}) +.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate") +.add_arguments(ConcatParam::__FIELDS__()); + NNVM_REGISTER_OP(_backward_np_concat) .set_num_outputs([](const NodeAttrs& attrs) { const ConcatParam& params = nnvm::get(attrs.parsed); @@ -304,5 +393,14 @@ NNVM_REGISTER_OP(_backward_np_concat) .set_attr("TIsBackward", true) .set_attr("FCompute", ConcatGradCompute); +NNVM_REGISTER_OP(_backward_np_dstack) +.set_num_outputs([](const NodeAttrs& attrs) { + const ConcatParam& params = nnvm::get(attrs.parsed); + return params.num_args; +}) +.set_attr_parser(ParamParser) +.set_attr("TIsBackward", true) +.set_attr("FCompute", DStackGradCompute); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index f192560f4ac9..b6ba22715cf8 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -41,6 +41,12 @@ NNVM_REGISTER_OP(_npi_concatenate) NNVM_REGISTER_OP(_backward_np_concat) .set_attr("FCompute", ConcatGradCompute); + +NNVM_REGISTER_OP(_npi_dstack) +.set_attr("FCompute", DStackCompute); + +NNVM_REGISTER_OP(_backward_np_dstack) +.set_attr("FCompute", DStackGradCompute); } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 2291bcdb6d3d..0a7f91853cde 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -880,6 +880,67 @@ def get_new_shape(shape, axis): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) +@with_seed() +@use_np +def test_np_dstack(): + class TestDStack(HybridBlock): + def __init__(self): + super(TestDStack, self).__init__() + + def hybrid_forward(self, F, a, *args): + return F.np.dstack([a] + list(args)) + + def get_new_shape(shape): + if len(shape) < 3: + return shape + axis = 2 + shape_lst = list(shape) + shape_lst[axis] = random.randint(0, 5) + return tuple(shape_lst) + + shapes = [ + (), + (1,), + (2,1), + (2,2,4), + (2,0,0), + (0,1,3), + (2,0,3), + (2,3,4,5) + ] + for hybridize in [True, False]: + for shape in shapes: + test_dstack = TestDStack() + if hybridize: + test_dstack.hybridize() + # test symbolic forward + a = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray() + a.attach_grad() + b = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray() + b.attach_grad() + c = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray() + c.attach_grad() + d = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray() + d.attach_grad() + with mx.autograd.record(): + mx_out = test_dstack(a, b, c, d) + np_out = _np.dstack((a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy())) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + # test symbolic backward + mx_out.backward() + assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5) + assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5) + assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5) + assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5) + + # test imperative + mx_out = np.dstack((a, b, c, d)) + np_out = _np.dstack((a.asnumpy(),b.asnumpy(), c.asnumpy(), d.asnumpy())) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + if __name__ == '__main__': import nose nose.runmodule() From a97d173b9e670c83a64f0ca9605ce94712c309d1 Mon Sep 17 00:00:00 2001 From: Mike Mao Date: Thu, 15 Aug 2019 09:18:43 +0000 Subject: [PATCH 2/2] Fix sanity --- src/operator/numpy/np_matrix_op.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index b6ba22715cf8..28df011b2ae3 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -41,7 +41,7 @@ NNVM_REGISTER_OP(_npi_concatenate) NNVM_REGISTER_OP(_backward_np_concat) .set_attr("FCompute", ConcatGradCompute); - + NNVM_REGISTER_OP(_npi_dstack) .set_attr("FCompute", DStackCompute);