diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index aea8b19c2913..2846d2bac5ce 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -32,8 +32,8 @@ 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', - 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', - 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', + 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack', + 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal'] @@ -2467,6 +2467,48 @@ def get_list(arrays): return _npi.vstack(*arrays) +@set_module('mxnet.ndarray.numpy') +def dstack(arrays): + """ + Stack arrays in sequence depth wise (along third axis). + This is equivalent to concatenation along the third axis after 2-D arrays + of shape `(M,N)` have been reshaped to `(M,N,1)` and 1-D arrays of shape + `(N,)` have been reshaped to `(1,N,1)`. Rebuilds arrays divided by + `dsplit`. + This function makes most sense for arrays with up to 3 dimensions. For + instance, for pixel-data with a height (first axis), width (second axis), + and r/g/b channels (third axis). The functions `concatenate`, `stack` and + `block` provide more general stacking and concatenation operations. + + Parameters + ---------- + tup : sequence of arrays + The arrays must have the same shape along all but the third axis. + 1-D or 2-D arrays must have the same shape. + + Returns + ------- + stacked : ndarray + The array formed by stacking the given arrays, will be at least 3-D. + + Examples + -------- + >>> a = np.array((1,2,3)) + >>> b = np.array((2,3,4)) + >>> np.dstack((a,b)) + array([[[1, 2], + [2, 3], + [3, 4]]]) + >>> a = np.array([[1],[2],[3]]) + >>> b = np.array([[2],[3],[4]]) + >>> np.dstack((a,b)) + array([[[1, 2]], + [[2, 3]], + [[3, 4]]]) + """ + return _npi.dstack(*arrays) + + @set_module('mxnet.ndarray.numpy') def maximum(x1, x2, out=None): """Returns element-wise maximum of the input arrays with broadcasting. diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index d3ae4d19aca8..00a770930529 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -51,8 +51,8 @@ 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', - 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', - 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', + 'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', + 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal'] @@ -4010,6 +4010,50 @@ def vstack(arrays, out=None): return _mx_nd_np.vstack(arrays) +@set_module('mxnet.numpy') +def dstack(arrays): + """ + Stack arrays in sequence depth wise (along third axis). + + This is equivalent to concatenation along the third axis after 2-D arrays + of shape `(M,N)` have been reshaped to `(M,N,1)` and 1-D arrays of shape + `(N,)` have been reshaped to `(1,N,1)`. Rebuilds arrays divided by + `dsplit`. + + This function makes most sense for arrays with up to 3 dimensions. For + instance, for pixel-data with a height (first axis), width (second axis), + and r/g/b channels (third axis). The functions `concatenate`, `stack` and + `block` provide more general stacking and concatenation operations. + + Parameters + ---------- + tup : sequence of arrays + The arrays must have the same shape along all but the third axis. + 1-D or 2-D arrays must have the same shape. + + Returns + ------- + stacked : ndarray + The array formed by stacking the given arrays, will be at least 3-D. + + Examples + -------- + >>> a = np.array((1,2,3)) + >>> b = np.array((2,3,4)) + >>> np.dstack((a,b)) + array([[[1, 2], + [2, 3], + [3, 4]]]) + >>> a = np.array([[1],[2],[3]]) + >>> b = np.array([[2],[3],[4]]) + >>> np.dstack((a,b)) + array([[[1, 2]], + [[2, 3]], + [[3, 4]]]) + """ + return _npi.dstack(*arrays) + + @set_module('mxnet.numpy') def maximum(x1, x2, out=None): """Returns element-wise maximum of the input arrays with broadcasting. diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 9a909420e934..de11cfbc01a5 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -34,8 +34,8 @@ 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', - 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', - 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', + 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack', + 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal'] @@ -2661,6 +2661,35 @@ def get_list(arrays): return _npi.vstack(*arrays) +@set_module('mxnet.symbol.numpy') +def dstack(arrays): + """ + Stack arrays in sequence depth wise (along third axis). + + This is equivalent to concatenation along the third axis after 2-D arrays + of shape `(M,N)` have been reshaped to `(M,N,1)` and 1-D arrays of shape + `(N,)` have been reshaped to `(1,N,1)`. Rebuilds arrays divided by + `dsplit`. + + This function makes most sense for arrays with up to 3 dimensions. For + instance, for pixel-data with a height (first axis), width (second axis), + and r/g/b channels (third axis). The functions `concatenate`, `stack` and + `block` provide more general stacking and concatenation operations. + + Parameters + ---------- + tup : sequence of _Symbol + The arrays must have the same shape along all but the first axis. + 1-D arrays must have the same length. + + Returns + ------- + stacked : _Symbol + The array formed by stacking the given arrays, will be at least 2-D. + """ + return _npi.dstack(*arrays) + + @set_module('mxnet.symbol.numpy') def maximum(x1, x2, out=None): return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out) diff --git a/src/operator/nn/concat-inl.h b/src/operator/nn/concat-inl.h index 7a58ae6f0ccc..1fb20ac65473 100644 --- a/src/operator/nn/concat-inl.h +++ b/src/operator/nn/concat-inl.h @@ -141,6 +141,37 @@ void ConcatCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, }); } +template +void DStackCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + ConcatParam param = nnvm::get(attrs.parsed); + param.dim = 2; + std::vector modified_inputs(inputs.size()); + for (int i = 0; i < param.num_args; ++i) { + if (inputs[i].shape_.ndim() == 0) { + modified_inputs[i] = inputs[i].reshape(TShape(3, 1)); + } else if (inputs[i].shape_.ndim() == 1) { + TShape t = TShape(3, 1); + t[1] = inputs[i].shape_[0]; + modified_inputs[i] = inputs[i].reshape(t); + } else if (inputs[i].shape_.ndim() == 2) { + TShape t = TShape(3, 1); + t[0] = inputs[i].shape_[0]; + t[1] = inputs[i].shape_[1]; + modified_inputs[i] = inputs[i].reshape(t); + } else { + modified_inputs[i] = inputs[i]; + } + } + MSHADOW_TYPE_SWITCH(inputs[concat_enum::kData0].type_flag_, DType, { + ConcatOp op; + op.Init(param); + op.Forward(ctx, modified_inputs, req, outputs); + }); +} + template void ConcatGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -154,6 +185,37 @@ void ConcatGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, }); } +template +void DStackGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + ConcatParam param = nnvm::get(attrs.parsed); + param.dim = 2; + std::vector modified_outputs(outputs.size()); + for (int i = 0; i < param.num_args; ++i) { + if (outputs[i].shape_.ndim() == 0) { + modified_outputs[i] = outputs[i].reshape(TShape(3, 1)); + } else if (outputs[i].shape_.ndim() == 1) { + TShape t = TShape(3, 1); + t[1] = outputs[i].shape_[0]; + modified_outputs[i] = outputs[i].reshape(t); + } else if (outputs[i].shape_.ndim() == 2) { + TShape t = TShape(3, 1); + t[0] = outputs[i].shape_[0]; + t[1] = outputs[i].shape_[1]; + modified_outputs[i] = outputs[i].reshape(t); + } else { + modified_outputs[i] = outputs[i]; + } + } + MSHADOW_TYPE_SWITCH(inputs[concat_enum::kOut].type_flag_, DType, { + ConcatOp op; + op.Init(param); + op.Backward(ctx, inputs[concat_enum::kOut], req, modified_outputs); + }); +} + /*! * \brief concat CSRNDArray on the first dimension. */ diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 96a10561be28..38044686fe6a 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -255,6 +255,67 @@ bool ConcatShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_shape, mxnet::ShapeVector *out_shape); +bool DStackShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_shape, + mxnet::ShapeVector *out_shape) { + using namespace mshadow; + ConcatParam param_ = nnvm::get(attrs.parsed); + CHECK_EQ(in_shape->size(), static_cast(param_.num_args)); + mxnet::TShape dshape; + dim_t size = 0; + bool has_unknown_dim_size = false; + int axis = 2; + param_.dim = axis; + for (int i = 0; i < param_.num_args; ++i) { + if ((*in_shape)[i].ndim() == 0) { + (*in_shape)[i] = mxnet::TShape(3, 1); + } else if ((*in_shape)[i].ndim() == 1) { + mxnet::TShape t = mxnet::TShape(3, 1); + t[1] = (*in_shape)[i][0]; + (*in_shape)[i] = t; + } else if ((*in_shape)[i].ndim() == 2) { + mxnet::TShape t = mxnet::TShape(3, 1); + t[0] = (*in_shape)[i][0]; + t[1] = (*in_shape)[i][1]; + (*in_shape)[i] = t; + } + mxnet::TShape &tmp = (*in_shape)[i]; + if (tmp.ndim() > 0) { + CheckAxis(axis, tmp.ndim()); + if (!mxnet::dim_size_is_known(tmp, axis)) { + has_unknown_dim_size = true; + } else { + size += tmp[axis]; + } + tmp[axis] = -1; + shape_assign(&dshape, tmp); + } + } + + mxnet::TShape tmp = (*out_shape)[0]; + if (tmp.ndim() > 0) { + axis = CheckAxis(param_.dim, tmp.ndim()); + tmp[axis] = -1; + shape_assign(&dshape, tmp); + } + + if (dshape.ndim() == -1) return false; + CHECK_NE(dshape.ndim(), 0) << "zero-dimensional arrays cannot be concatenated"; + + for (int i = 0; i < param_.num_args; ++i) { + CHECK(shape_assign(&(*in_shape)[i], dshape)) + << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i]; + } + + if (!has_unknown_dim_size) { + dshape[axis] = size; + } + CHECK(shape_assign(&(*out_shape)[0], dshape)) + << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0]; + + return shape_is_known(dshape); +} + bool ConcatType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type); @@ -269,7 +330,6 @@ struct NumpyConcatGrad { } }; - NNVM_REGISTER_OP(_npi_concatenate) .describe(R"code(Join a sequence of arrays along an existing axis.)code" ADD_FILELINE) .set_num_inputs([](const NodeAttrs& attrs) { @@ -490,6 +550,44 @@ NNVM_REGISTER_OP(_backward_np_vstack) .set_attr("TIsBackward", true) .set_attr("FCompute", NumpyVstackBackward); +NNVM_REGISTER_OP(_npi_dstack) +.describe(R"code(Stack tensors in sequence depthwise (in third dimension))code" ADD_FILELINE) +.set_num_inputs([](const NodeAttrs& attrs) { + const ConcatParam& params = nnvm::get(attrs.parsed); + return params.num_args; +}) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const ConcatParam& params = nnvm::get(attrs.parsed); + std::vector ret; + for (int i = 0; i < params.num_args; ++i) { + ret.push_back(std::string("data") + std::to_string(i)); + } + return ret; +}) +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"out"}; +}) +.set_attr("key_var_num_args", "num_args") +.set_attr("FInferType", ConcatType) +.set_attr("FInferShape", DStackShape) +.set_attr("FCompute", DStackCompute) +.set_attr("FGradient", NumpyConcatGrad{"_backward_np_dstack"}) +.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate") +.add_arguments(ConcatParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_np_dstack) +.set_num_outputs([](const NodeAttrs& attrs) { + const ConcatParam& params = nnvm::get(attrs.parsed); + return params.num_args; +}) +.set_attr_parser(ParamParser) +.set_attr("TIsBackward", true) +.set_attr("FCompute", DStackGradCompute); + inline bool NumpyRollShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_attrs, mxnet::ShapeVector *out_attrs) { diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index caab4108b40e..125cd91acd1f 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -53,6 +53,12 @@ NNVM_REGISTER_OP(_npi_vstack) NNVM_REGISTER_OP(_backward_np_vstack) .set_attr("FCompute", NumpyVstackBackward); +NNVM_REGISTER_OP(_npi_dstack) +.set_attr("FCompute", DStackCompute); + +NNVM_REGISTER_OP(_backward_np_dstack) +.set_attr("FCompute", DStackGradCompute); + NNVM_REGISTER_OP(_np_roll) .set_attr("FCompute", NumpyRollCompute); @@ -90,5 +96,6 @@ NNVM_REGISTER_OP(_npi_flip) NNVM_REGISTER_OP(_backward_npi_flip) .set_attr("FCompute", NumpyFlipForward); + } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 1f90f30bd8c3..e9fa84e0cd7b 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1685,6 +1685,67 @@ def hybrid_forward(self, F, a, *args): assert same(mx_out.asnumpy(), np_out) +@with_seed() +@use_np +def test_np_dstack(): + class TestDStack(HybridBlock): + def __init__(self): + super(TestDStack, self).__init__() + + def hybrid_forward(self, F, a, *args): + return F.np.dstack([a] + list(args)) + + def get_new_shape(shape): + if len(shape) < 3: + return shape + axis = 2 + shape_lst = list(shape) + shape_lst[axis] = random.randint(0, 5) + return tuple(shape_lst) + + shapes = [ + (), + (1,), + (2,1), + (2,2,4), + (2,0,0), + (0,1,3), + (2,0,3), + (2,3,4,5) + ] + for hybridize in [True, False]: + for shape in shapes: + test_dstack = TestDStack() + if hybridize: + test_dstack.hybridize() + # test symbolic forward + a = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray() + a.attach_grad() + b = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray() + b.attach_grad() + c = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray() + c.attach_grad() + d = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray() + d.attach_grad() + with mx.autograd.record(): + mx_out = test_dstack(a, b, c, d) + np_out = _np.dstack((a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy())) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + # test symbolic backward + mx_out.backward() + assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5) + assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5) + assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5) + assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5) + + # test imperative + mx_out = np.dstack((a, b, c, d)) + np_out = _np.dstack((a.asnumpy(),b.asnumpy(), c.asnumpy(), d.asnumpy())) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + @with_seed() @use_np def test_np_ravel():