
[Numpy] Numpy compatible dstack (#15871)
* Add dstack that passes CPU test

Register dstack on GPU

Minor comment fix

Minor syntax fix

Syntax fix according to comments

Header fix

* Fix sanity
Mike authored and haojin2 committed Oct 11, 2019
1 parent 7f5e687 commit c13806b
Showing 7 changed files with 350 additions and 7 deletions.
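
For context, a quick usage sketch of the new operator (assuming a build that includes this commit, with NumPy semantics enabled via npx.set_np()):

>>> from mxnet import np, npx
>>> npx.set_np()
>>> np.dstack((np.array([1, 2, 3]), np.array([2, 3, 4]))).shape
(1, 3, 2)
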
46 changes: 44 additions & 2 deletions python/mxnet/ndarray/numpy/_op.py
@@ -32,8 +32,8 @@
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram',
- 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
- 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
+ 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack',
+ 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']
@@ -2467,6 +2467,48 @@ def get_list(arrays):
    return _npi.vstack(*arrays)


@set_module('mxnet.ndarray.numpy')
def dstack(arrays):
"""
Stack arrays in sequence depth wise (along third axis).
This is equivalent to concatenation along the third axis after 2-D arrays
of shape `(M,N)` have been reshaped to `(M,N,1)` and 1-D arrays of shape
`(N,)` have been reshaped to `(1,N,1)`. Rebuilds arrays divided by
`dsplit`.
This function makes most sense for arrays with up to 3 dimensions. For
instance, for pixel-data with a height (first axis), width (second axis),
and r/g/b channels (third axis). The functions `concatenate`, `stack` and
`block` provide more general stacking and concatenation operations.
Parameters
----------
tup : sequence of arrays
The arrays must have the same shape along all but the third axis.
1-D or 2-D arrays must have the same shape.
Returns
-------
stacked : ndarray
The array formed by stacking the given arrays, will be at least 3-D.
Examples
--------
>>> a = np.array((1,2,3))
>>> b = np.array((2,3,4))
>>> np.dstack((a,b))
array([[[1, 2],
[2, 3],
[3, 4]]])
>>> a = np.array([[1],[2],[3]])
>>> b = np.array([[2],[3],[4]])
>>> np.dstack((a,b))
array([[[1, 2]],
[[2, 3]],
[[3, 4]]])
"""
return _npi.dstack(*arrays)


@set_module('mxnet.ndarray.numpy')
def maximum(x1, x2, out=None):
"""Returns element-wise maximum of the input arrays with broadcasting.
48 changes: 46 additions & 2 deletions python/mxnet/numpy/multiarray.py
@@ -51,8 +51,8 @@
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'histogram', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
- 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
- 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
+ 'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var',
+ 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']

@@ -4010,6 +4010,50 @@ def vstack(arrays, out=None):
    return _mx_nd_np.vstack(arrays)


@set_module('mxnet.numpy')
def dstack(arrays):
"""
Stack arrays in sequence depth wise (along third axis).
This is equivalent to concatenation along the third axis after 2-D arrays
of shape `(M,N)` have been reshaped to `(M,N,1)` and 1-D arrays of shape
`(N,)` have been reshaped to `(1,N,1)`. Rebuilds arrays divided by
`dsplit`.
This function makes most sense for arrays with up to 3 dimensions. For
instance, for pixel-data with a height (first axis), width (second axis),
and r/g/b channels (third axis). The functions `concatenate`, `stack` and
`block` provide more general stacking and concatenation operations.
Parameters
----------
tup : sequence of arrays
The arrays must have the same shape along all but the third axis.
1-D or 2-D arrays must have the same shape.
Returns
-------
stacked : ndarray
The array formed by stacking the given arrays, will be at least 3-D.
Examples
--------
>>> a = np.array((1,2,3))
>>> b = np.array((2,3,4))
>>> np.dstack((a,b))
array([[[1, 2],
[2, 3],
[3, 4]]])
>>> a = np.array([[1],[2],[3]])
>>> b = np.array([[2],[3],[4]])
>>> np.dstack((a,b))
array([[[1, 2]],
[[2, 3]],
[[3, 4]]])
"""
    return _mx_nd_np.dstack(arrays)


@set_module('mxnet.numpy')
def maximum(x1, x2, out=None):
"""Returns element-wise maximum of the input arrays with broadcasting.
33 changes: 31 additions & 2 deletions python/mxnet/symbol/numpy/_symbol.py
@@ -34,8 +34,8 @@
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram',
- 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
- 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
+ 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack',
+ 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']
@@ -2661,6 +2661,35 @@ def get_list(arrays):
    return _npi.vstack(*arrays)


@set_module('mxnet.symbol.numpy')
def dstack(arrays):
"""
Stack arrays in sequence depth wise (along third axis).
This is equivalent to concatenation along the third axis after 2-D arrays
of shape `(M,N)` have been reshaped to `(M,N,1)` and 1-D arrays of shape
`(N,)` have been reshaped to `(1,N,1)`. Rebuilds arrays divided by
`dsplit`.
This function makes most sense for arrays with up to 3 dimensions. For
instance, for pixel-data with a height (first axis), width (second axis),
and r/g/b channels (third axis). The functions `concatenate`, `stack` and
`block` provide more general stacking and concatenation operations.
Parameters
----------
tup : sequence of _Symbol
The arrays must have the same shape along all but the first axis.
1-D arrays must have the same length.
Returns
-------
stacked : _Symbol
The array formed by stacking the given arrays, will be at least 2-D.
"""
return _npi.dstack(*arrays)
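
In practice the symbolic frontend is reached through Gluon's deferred execution rather than called directly. A minimal sketch, assuming the standard F.np dispatch (the block name and shapes below are illustrative, not part of this commit):

from mxnet import np, npx
from mxnet.gluon import HybridBlock
npx.set_np()

class DStackBlock(HybridBlock):
    def hybrid_forward(self, F, a, b):
        # F is mxnet.ndarray or mxnet.symbol depending on hybridization
        return F.np.dstack((a, b))

block = DStackBlock()
block.hybridize()  # subsequent calls route through mxnet.symbol.numpy.dstack
print(block(np.zeros((2, 2)), np.ones((2, 2))).shape)  # (2, 2, 2)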


@set_module('mxnet.symbol.numpy')
def maximum(x1, x2, out=None):
    return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out)
62 changes: 62 additions & 0 deletions src/operator/nn/concat-inl.h
@@ -141,6 +141,37 @@ void ConcatCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
  });
}

template<typename xpu>
void DStackCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
                   const std::vector<TBlob>& inputs,
                   const std::vector<OpReqType>& req,
                   const std::vector<TBlob>& outputs) {
  ConcatParam param = nnvm::get<ConcatParam>(attrs.parsed);
  param.dim = 2;
  std::vector<TBlob> modified_inputs(inputs.size());
  for (int i = 0; i < param.num_args; ++i) {
    if (inputs[i].shape_.ndim() == 0) {
      modified_inputs[i] = inputs[i].reshape(TShape(3, 1));
    } else if (inputs[i].shape_.ndim() == 1) {
      TShape t = TShape(3, 1);
      t[1] = inputs[i].shape_[0];
      modified_inputs[i] = inputs[i].reshape(t);
    } else if (inputs[i].shape_.ndim() == 2) {
      TShape t = TShape(3, 1);
      t[0] = inputs[i].shape_[0];
      t[1] = inputs[i].shape_[1];
      modified_inputs[i] = inputs[i].reshape(t);
    } else {
      modified_inputs[i] = inputs[i];
    }
  }
  MSHADOW_TYPE_SWITCH(inputs[concat_enum::kData0].type_flag_, DType, {
    ConcatOp<xpu, DType> op;
    op.Init(param);
    op.Forward(ctx, modified_inputs, req, outputs);
  });
}
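
The per-input reshape above is NumPy's atleast_3d promotion rule. A hedged Python sketch of the same mapping (promote_to_3d is a hypothetical helper, written here only to document the C++ branches):

import numpy as onp  # reference NumPy, used only to cross-check the rule

def promote_to_3d(shape):
    """Mirror the ndim-0/1/2 branches in DStackCompute."""
    if len(shape) == 0:
        return (1, 1, 1)                # scalar -> (1, 1, 1)
    if len(shape) == 1:
        return (1, shape[0], 1)         # (N,)   -> (1, N, 1)
    if len(shape) == 2:
        return (shape[0], shape[1], 1)  # (M, N) -> (M, N, 1)
    return shape                        # ndim >= 3 passes through unchanged

assert promote_to_3d((3,)) == onp.atleast_3d(onp.empty(3)).shape         # (1, 3, 1)
assert promote_to_3d((2, 3)) == onp.atleast_3d(onp.empty((2, 3))).shape  # (2, 3, 1)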

template<typename xpu>
void ConcatGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
                       const std::vector<TBlob>& inputs,
@@ -154,6 +185,37 @@ void ConcatGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
  });
}

template<typename xpu>
void DStackGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
                       const std::vector<TBlob>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
  ConcatParam param = nnvm::get<ConcatParam>(attrs.parsed);
  param.dim = 2;
  std::vector<TBlob> modified_outputs(outputs.size());
  for (int i = 0; i < param.num_args; ++i) {
    if (outputs[i].shape_.ndim() == 0) {
      modified_outputs[i] = outputs[i].reshape(TShape(3, 1));
    } else if (outputs[i].shape_.ndim() == 1) {
      TShape t = TShape(3, 1);
      t[1] = outputs[i].shape_[0];
      modified_outputs[i] = outputs[i].reshape(t);
    } else if (outputs[i].shape_.ndim() == 2) {
      TShape t = TShape(3, 1);
      t[0] = outputs[i].shape_[0];
      t[1] = outputs[i].shape_[1];
      modified_outputs[i] = outputs[i].reshape(t);
    } else {
      modified_outputs[i] = outputs[i];
    }
  }
  MSHADOW_TYPE_SWITCH(inputs[concat_enum::kOut].type_flag_, DType, {
    ConcatOp<xpu, DType> op;
    op.Init(param);
    op.Backward(ctx, inputs[concat_enum::kOut], req, modified_outputs);
  });
}

/*!
 * \brief concat CSRNDArray on the first dimension.
 */
100 changes: 99 additions & 1 deletion src/operator/numpy/np_matrix_op.cc
@@ -255,6 +255,67 @@ bool ConcatShape(const nnvm::NodeAttrs& attrs,
                 mxnet::ShapeVector *in_shape,
                 mxnet::ShapeVector *out_shape);

bool DStackShape(const nnvm::NodeAttrs& attrs,
                 mxnet::ShapeVector *in_shape,
                 mxnet::ShapeVector *out_shape) {
  using namespace mshadow;
  ConcatParam param_ = nnvm::get<ConcatParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
  mxnet::TShape dshape;
  dim_t size = 0;
  bool has_unknown_dim_size = false;
  int axis = 2;
  param_.dim = axis;
  for (int i = 0; i < param_.num_args; ++i) {
    if ((*in_shape)[i].ndim() == 0) {
      (*in_shape)[i] = mxnet::TShape(3, 1);
    } else if ((*in_shape)[i].ndim() == 1) {
      mxnet::TShape t = mxnet::TShape(3, 1);
      t[1] = (*in_shape)[i][0];
      (*in_shape)[i] = t;
    } else if ((*in_shape)[i].ndim() == 2) {
      mxnet::TShape t = mxnet::TShape(3, 1);
      t[0] = (*in_shape)[i][0];
      t[1] = (*in_shape)[i][1];
      (*in_shape)[i] = t;
    }
    mxnet::TShape &tmp = (*in_shape)[i];
    if (tmp.ndim() > 0) {
      CheckAxis(axis, tmp.ndim());
      if (!mxnet::dim_size_is_known(tmp, axis)) {
        has_unknown_dim_size = true;
      } else {
        size += tmp[axis];
      }
      tmp[axis] = -1;
      shape_assign(&dshape, tmp);
    }
  }

  mxnet::TShape tmp = (*out_shape)[0];
  if (tmp.ndim() > 0) {
    axis = CheckAxis(param_.dim, tmp.ndim());
    tmp[axis] = -1;
    shape_assign(&dshape, tmp);
  }

  if (dshape.ndim() == -1) return false;
  CHECK_NE(dshape.ndim(), 0) << "zero-dimensional arrays cannot be concatenated";

  for (int i = 0; i < param_.num_args; ++i) {
    CHECK(shape_assign(&(*in_shape)[i], dshape))
        << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
  }

  if (!has_unknown_dim_size) {
    dshape[axis] = size;
  }
  CHECK(shape_assign(&(*out_shape)[0], dshape))
      << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];

  return shape_is_known(dshape);
}
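
The net effect of the shape function: promote every input to 3-D, require the first two dimensions to agree, and sum the sizes along axis 2. A sketch under the same assumptions, reusing the hypothetical promote_to_3d helper from the sketch above:

def dstack_out_shape(shapes):
    """What DStackShape infers when all input shapes are fully known."""
    promoted = [promote_to_3d(s) for s in shapes]
    m, n = promoted[0][:2]
    # shape_assign enforces this agreement check in the C++ version
    assert all(p[:2] == (m, n) for p in promoted), "incompatible input shapes"
    return (m, n, sum(p[2] for p in promoted))

print(dstack_out_shape([(3,), (3,)]))      # (1, 3, 2)
print(dstack_out_shape([(3, 1), (3, 1)]))  # (3, 1, 2)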

bool ConcatType(const nnvm::NodeAttrs& attrs,
                std::vector<int> *in_type,
                std::vector<int> *out_type);
@@ -269,7 +330,6 @@ struct NumpyConcatGrad {
  }
};


NNVM_REGISTER_OP(_npi_concatenate)
.describe(R"code(Join a sequence of arrays along an existing axis.)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
@@ -490,6 +550,44 @@ NNVM_REGISTER_OP(_backward_np_vstack)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", NumpyVstackBackward<cpu>);

NNVM_REGISTER_OP(_npi_dstack)
.describe(R"code(Stack tensors in sequence depthwise (in third dimension))code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
  return params.num_args;
})
.set_num_outputs(1)
.set_attr_parser(ParamParser<ConcatParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
    std::vector<std::string> ret;
    for (int i = 0; i < params.num_args; ++i) {
      ret.push_back(std::string("data") + std::to_string(i));
    }
    return ret;
  })
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"out"};
  })
.set_attr<std::string>("key_var_num_args", "num_args")
.set_attr<nnvm::FInferType>("FInferType", ConcatType)
.set_attr<mxnet::FInferShape>("FInferShape", DStackShape)
.set_attr<FCompute>("FCompute<cpu>", DStackCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", NumpyConcatGrad{"_backward_np_dstack"})
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_np_dstack)
.set_num_outputs([](const NodeAttrs& attrs) {
  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
  return params.num_args;
})
.set_attr_parser(ParamParser<ConcatParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", DStackGradCompute<cpu>);

inline bool NumpyRollShape(const nnvm::NodeAttrs& attrs,
                           mxnet::ShapeVector *in_attrs,
                           mxnet::ShapeVector *out_attrs) {
7 changes: 7 additions & 0 deletions src/operator/numpy/np_matrix_op.cu
@@ -53,6 +53,12 @@ NNVM_REGISTER_OP(_npi_vstack)
NNVM_REGISTER_OP(_backward_np_vstack)
.set_attr<FCompute>("FCompute<gpu>", NumpyVstackBackward<gpu>);

NNVM_REGISTER_OP(_npi_dstack)
.set_attr<FCompute>("FCompute<gpu>", DStackCompute<gpu>);

NNVM_REGISTER_OP(_backward_np_dstack)
.set_attr<FCompute>("FCompute<gpu>", DStackGradCompute<gpu>);

NNVM_REGISTER_OP(_np_roll)
.set_attr<FCompute>("FCompute<gpu>", NumpyRollCompute<gpu>);

@@ -90,5 +96,6 @@ NNVM_REGISTER_OP(_npi_flip)

NNVM_REGISTER_OP(_backward_npi_flip)
.set_attr<FCompute>("FCompute<gpu>", NumpyFlipForward<gpu>);

} // namespace op
} // namespace mxnet