From 6c42992f2a12c253ab92fa7a85556de5da357f89 Mon Sep 17 00:00:00 2001 From: JiangZhaoh <54654391+JiangZhaoh@users.noreply.github.com> Date: Fri, 1 Nov 2019 16:20:06 +0800 Subject: [PATCH] [numpy] add numpy operator : append (#16564) * add operator : append ; fix op concatenate when axis = None * pylint disable remove mistake disable pylint --- python/mxnet/ndarray/numpy/_op.py | 62 ++++++++- python/mxnet/numpy/multiarray.py | 47 ++++++- python/mxnet/numpy_dispatch_protocol.py | 1 + python/mxnet/symbol/numpy/_symbol.py | 64 ++++++++- src/operator/numpy/np_matrix_op-inl.h | 81 +++++++++++ src/operator/numpy/np_matrix_op.cc | 105 +++++++++++++-- src/operator/numpy/np_matrix_op.cu | 4 +- .../unittest/test_numpy_interoperability.py | 20 +++ tests/python/unittest/test_numpy_op.py | 127 ++++++++++++++---- 9 files changed, 458 insertions(+), 53 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 256cfb7d5708..c215159edb5e 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -33,7 +33,7 @@ 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye', - 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', + 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', @@ -2919,8 +2919,64 @@ def concatenate(seq, axis=0, out=None): ------- res : ndarray The concatenated array. + + Examples + -------- + >>> a = np.array([[1, 2], [3, 4]]) + >>> b = np.array([[5, 6]]) + >>> np.concatenate((a, b), axis=0) + array([[1., 2.], + [3., 4.], + [5., 6.]]) + + >>> np.concatenate((a, b), axis=None) + array([1., 2., 3., 4., 5., 6.]) + + >>> np.concatenate((a, b.T), axis=1) + array([[1., 2., 5.], + [3., 4., 6.]]) + """ + return _npi.concatenate(*seq, axis=axis, out=out) + + +@set_module('mxnet.ndarray.numpy') +def append(arr, values, axis=None): # pylint: disable=redefined-outer-name + """ + Append values to the end of an array. + + Parameters + ---------- + arr : ndarray + Values are appended to a copy of this array. + values : ndarray + These values are appended to a copy of `arr`. It must be of the + correct shape (the same shape as `arr`, excluding `axis`). If + `axis` is not specified, `values` can be any shape and will be + flattened before use. + axis : int, optional + The axis along which `values` are appended. If `axis` is not + given, both `arr` and `values` are flattened before use. + + Returns + ------- + append : ndarray + A copy of `arr` with `values` appended to `axis`. Note that + `append` does not occur in-place: a new array is allocated and + filled. If `axis` is None, `out` is a flattened array. + + Examples + -------- + >>> np.append(np.array([1, 2, 3]), np.array([[4, 5, 6],[7, 8, 9]])) + array([1., 2., 3., 4., 5., 6., 7., 8., 9.]) + + When `axis` is specified, `values` must have the correct shape. + + >>> np.append(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[7, 8, 9]]), axis=0) + array([[1., 2., 3.], + [4., 5., 6.], + [7., 8., 9.]]) """ - return _npi.concatenate(*seq, dim=axis, out=out) + return _npi.concatenate(arr, values, axis=axis, out=None) @set_module('mxnet.ndarray.numpy') @@ -5014,7 +5070,7 @@ def may_share_memory(a, b, max_work=None): return _npi.share_memory(a, b).item() -def diff(a, n=1, axis=-1, prepend=None, append=None): +def diff(a, n=1, axis=-1, prepend=None, append=None): # pylint: disable=redefined-outer-name r""" numpy.diff(a, n=1, axis=-1, prepend=, append=) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 8e0d5b209a8d..85bd2ac0e2b6 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -49,7 +49,7 @@ 'mod', 'remainder', 'power', 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', - 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', + 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append', 'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', @@ -4803,10 +4803,53 @@ def concatenate(seq, axis=0, out=None): >>> np.concatenate((a, b.T), axis=1) array([[1., 2., 5.], [3., 4., 6.]]) + + >>> np.concatenate((a, b), axis=None) + array([1., 2., 3., 4., 5., 6.]) """ return _mx_nd_np.concatenate(seq, axis=axis, out=out) +@set_module('mxnet.numpy') +def append(arr, values, axis=None): # pylint: disable=redefined-outer-name + """ + Append values to the end of an array. + + Parameters + ---------- + arr : ndarray + Values are appended to a copy of this array. + values : ndarray + These values are appended to a copy of `arr`. It must be of the + correct shape (the same shape as `arr`, excluding `axis`). If + `axis` is not specified, `values` can be any shape and will be + flattened before use. + axis : int, optional + The axis along which `values` are appended. If `axis` is not + given, both `arr` and `values` are flattened before use. + + Returns + ------- + append : ndarray + A copy of `arr` with `values` appended to `axis`. Note that + `append` does not occur in-place: a new array is allocated and + filled. If `axis` is None, `out` is a flattened array. + + Examples + -------- + >>> np.append(np.array([1, 2, 3]), np.array([[4, 5, 6],[7, 8, 9]])) + array([1., 2., 3., 4., 5., 6., 7., 8., 9.]) + + When `axis` is specified, `values` must have the correct shape. + + >>> np.append(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[7, 8, 9]]), axis=0) + array([[1., 2., 3.], + [4., 5., 6.], + [7., 8., 9.]]) + """ + return _mx_nd_np.append(arr, values, axis=axis) + + @set_module('mxnet.numpy') def stack(arrays, axis=0, out=None): """Join a sequence of arrays along a new axis. @@ -7018,7 +7061,7 @@ def may_share_memory(a, b, max_work=None): return _mx_nd_np.may_share_memory(a, b, max_work) -def diff(a, n=1, axis=-1, prepend=None, append=None): +def diff(a, n=1, axis=-1, prepend=None, append=None): # pylint: disable=redefined-outer-name r""" numpy.diff(a, n=1, axis=-1, prepend=, append=) diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index cfab2a49699d..cdd21af829de 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -86,6 +86,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'argmin', 'argmax', 'around', + 'append', 'broadcast_arrays', 'broadcast_to', 'clip', diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 7469875f267a..d3837d2bd1dd 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -35,7 +35,7 @@ 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye', - 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', + 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', @@ -2992,6 +2992,7 @@ def vsplit(ary, indices_or_sections): @set_module('mxnet.symbol.numpy') def concatenate(seq, axis=0, out=None): """Join a sequence of arrays along an existing axis. + Parameters ---------- a1, a2, ... : sequence of array_like @@ -3004,12 +3005,69 @@ def concatenate(seq, axis=0, out=None): If provided, the destination to place the result. The shape must be correct, matching that of what concatenate would have returned if no out argument were specified. + Returns ------- res : ndarray The concatenated array. + + Examples + -------- + >>> a = np.array([[1, 2], [3, 4]]) + >>> b = np.array([[5, 6]]) + >>> np.concatenate((a, b), axis=0) + array([[1., 2.], + [3., 4.], + [5., 6.]]) + + >>> np.concatenate((a, b), axis=None) + array([1., 2., 3., 4., 5., 6.]) + + >>> np.concatenate((a, b.T), axis=1) + array([[1., 2., 5.], + [3., 4., 6.]]) + """ + return _npi.concatenate(*seq, axis=axis, out=out) + + +@set_module('mxnet.symbol.numpy') +def append(arr, values, axis=None): # pylint: disable=redefined-outer-name + """ + Append values to the end of an array. + + Parameters + ---------- + arr : ndarray + Values are appended to a copy of this array. + values : ndarray + These values are appended to a copy of `arr`. It must be of the + correct shape (the same shape as `arr`, excluding `axis`). If + `axis` is not specified, `values` can be any shape and will be + flattened before use. + axis : int, optional + The axis along which `values` are appended. If `axis` is not + given, both `arr` and `values` are flattened before use. + + Returns + ------- + append : ndarray + A copy of `arr` with `values` appended to `axis`. Note that + `append` does not occur in-place: a new array is allocated and + filled. If `axis` is None, `out` is a flattened array. + + Examples + -------- + >>> np.append(np.array([1, 2, 3]), np.array([[4, 5, 6],[7, 8, 9]])) + array([1., 2., 3., 4., 5., 6., 7., 8., 9.]) + + When `axis` is specified, `values` must have the correct shape. + + >>> np.append(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[7, 8, 9]]), axis=0) + array([[1., 2., 3.], + [4., 5., 6.], + [7., 8., 9.]]) """ - return _npi.concatenate(*seq, dim=axis, out=out) + return _npi.concatenate(arr, values, axis=axis, out=None) @set_module('mxnet.symbol.numpy') @@ -4665,7 +4723,7 @@ def may_share_memory(a, b, max_work=None): return _npi.share_memory(a, b) -def diff(a, n=1, axis=-1, prepend=None, append=None): +def diff(a, n=1, axis=-1, prepend=None, append=None): # pylint: disable=redefined-outer-name r""" numpy.diff(a, n=1, axis=-1, prepend=, append=) diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index 2545adcb3555..a9828f40436d 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -864,6 +864,87 @@ inline void HSplitOpBackward(const nnvm::NodeAttrs &attrs, } SplitOpBackwardImpl(attrs, ctx, inputs, req, outputs, real_axis); } + +struct NumpyConcatenateParam : public dmlc::Parameter { + int num_args; + dmlc::optional axis; + DMLC_DECLARE_PARAMETER(NumpyConcatenateParam) { + DMLC_DECLARE_FIELD(num_args) + .set_lower_bound(1) + .describe("Number of inputs to be concated."); + DMLC_DECLARE_FIELD(axis) + .set_default(dmlc::optional(0)) + .describe("The axis along which `values` are appended. If `axis` is not" + "given, both `arr` and `values` are flattened before use."); + } +}; + +template +void NumpyConcatenateForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow_op; + + const NumpyConcatenateParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(inputs.size(), param.num_args); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + + std::vector data(param.num_args); + for (int i = 0; i < param.num_args; i++) { + if (!param.axis.has_value()) { + data[i] = inputs[i].reshape(Shape1(inputs[i].shape_.Size())); + } else { + data[i] = inputs[i]; + } + } + + ConcatParam cparam; + cparam.num_args = param.num_args; + cparam.dim = param.axis.has_value() ? param.axis.value() : 0; + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + ConcatOp op; + op.Init(cparam); + op.Forward(ctx, data, req, outputs); + }); +} + +template +void NumpyConcatenateBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow_op; + + const NumpyConcatenateParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), param.num_args); + CHECK_EQ(req.size(), param.num_args); + + std::vector data(param.num_args); + for (int i = 0; i < param.num_args; i++) { + if (!param.axis.has_value()) { + data[i] = outputs[i].reshape(Shape1(outputs[i].shape_.Size())); + } else { + data[i] = outputs[i]; + } + } + + ConcatParam cparam; + cparam.num_args = param.num_args; + cparam.dim = param.axis.has_value() ? param.axis.value() : 0; + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + ConcatOp op; + op.Init(cparam); + op.Backward(ctx, inputs[0], req, data); + }); +} + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 18594cd9cff1..3967cde91d2a 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -456,10 +456,6 @@ NNVM_REGISTER_OP(_np_squeeze) .add_argument("a", "NDArray-or-Symbol", "data to squeeze") .add_arguments(SqueezeParam::__FIELDS__()); -bool ConcatShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_shape, - mxnet::ShapeVector *out_shape); - bool DStackShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_shape, mxnet::ShapeVector *out_shape) { @@ -525,6 +521,84 @@ bool ConcatType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type); +bool NumpyConcatenateType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type) { + const NumpyConcatenateParam& param = nnvm::get(attrs.parsed); + const int num_args = param.num_args; + CHECK_EQ(in_type->size(), num_args); + CHECK_EQ(out_type->size(), 1); + int dtype = -1; + for (int i = 0; i < num_args; i++) { + if (dtype == -1) { + dtype = in_type->at(i); + } + } + if (dtype == -1) { + dtype = out_type->at(0); + } + for (int i = 0; i < num_args; i++) { + TYPE_ASSIGN_CHECK(*in_type, i, dtype); + } + TYPE_ASSIGN_CHECK(*out_type, 0, dtype); + return dtype != -1; +} + +bool NumpyConcatenateShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_shape, + mxnet::ShapeVector *out_shape) { + using namespace mshadow; + const NumpyConcatenateParam& param_ = nnvm::get(attrs.parsed); + const int num_args = param_.num_args; + CHECK_EQ(in_shape->size(), num_args); + + int param_axis; + if (!(param_.axis.has_value())) { + for (int i = 0 ; i < num_args ; ++i) { + (*in_shape)[i] = Shape1((*in_shape)[i].Size()); + } + param_axis = 0; + } else { + param_axis = param_.axis.value(); + } + + mxnet::TShape dshape; + dim_t size = 0; + bool has_unknown_dim_size = false; + int axis = -1; + for (int i = 0; i < num_args; ++i) { + mxnet::TShape tmp = (*in_shape)[i]; + if (tmp.ndim() > 0) { + axis = CheckAxis(param_axis, tmp.ndim()); + has_unknown_dim_size = !mxnet::dim_size_is_known(tmp, axis) || has_unknown_dim_size; + size += tmp[axis]; + tmp[axis] = -1; + shape_assign(&dshape, tmp); + } + } + + mxnet::TShape tmp = (*out_shape)[0]; + if (tmp.ndim() > 0) { + axis = CheckAxis(param_axis, tmp.ndim()); + tmp[axis] = -1; + shape_assign(&dshape, tmp); + } + + if (dshape.ndim() == -1) return false; + CHECK_NE(dshape.ndim(), 0) << "zero-dimensional arrays cannot be concatenated"; + + for (int i = 0; i < num_args; ++i) { + CHECK(shape_assign(&(*in_shape)[i], dshape)) + << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i]; + } + + if (!has_unknown_dim_size) dshape[axis] = size; + CHECK(shape_assign(&(*out_shape)[0], dshape)) + << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0]; + + return shape_is_known(dshape); +} + struct NumpyConcatGrad { const char *op_name; std::vector operator()(const nnvm::NodePtr& n, @@ -535,17 +609,19 @@ struct NumpyConcatGrad { } }; +DMLC_REGISTER_PARAMETER(NumpyConcatenateParam); + NNVM_REGISTER_OP(_npi_concatenate) .describe(R"code(Join a sequence of arrays along an existing axis.)code" ADD_FILELINE) .set_num_inputs([](const NodeAttrs& attrs) { - const ConcatParam& params = nnvm::get(attrs.parsed); + const NumpyConcatenateParam& params = nnvm::get(attrs.parsed); return params.num_args; }) .set_num_outputs(1) -.set_attr_parser(ParamParser) +.set_attr_parser(ParamParser) .set_attr("FListInputNames", [](const NodeAttrs& attrs) { - const ConcatParam& params = nnvm::get(attrs.parsed); + const NumpyConcatenateParam& params = nnvm::get(attrs.parsed); std::vector ret; for (int i = 0; i < params.num_args; ++i) { ret.push_back(std::string("data") + std::to_string(i)); @@ -557,21 +633,22 @@ NNVM_REGISTER_OP(_npi_concatenate) return std::vector{"out"}; }) .set_attr("key_var_num_args", "num_args") -.set_attr("FInferType", ConcatType) -.set_attr("FInferShape", ConcatShape) -.set_attr("FCompute", ConcatCompute) -.set_attr("FGradient", NumpyConcatGrad{"_backward_np_concat"}) +.set_attr("FInferType", NumpyConcatenateType) +.set_attr("FInferShape", NumpyConcatenateShape) +.set_attr("FCompute", NumpyConcatenateForward) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_np_concat"}) .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate") .add_arguments(ConcatParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_np_concat) +.set_num_inputs(1) .set_num_outputs([](const NodeAttrs& attrs) { - const ConcatParam& params = nnvm::get(attrs.parsed); + const NumpyConcatenateParam& params = nnvm::get(attrs.parsed); return params.num_args; }) -.set_attr_parser(ParamParser) +.set_attr_parser(ParamParser) .set_attr("TIsBackward", true) -.set_attr("FCompute", ConcatGradCompute); +.set_attr("FCompute", NumpyConcatenateBackward); NNVM_REGISTER_OP(_npi_stack) .describe(R"code(Join a sequence of arrays along a new axis. diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index fccc8f257e64..7ca205565413 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -39,10 +39,10 @@ NNVM_REGISTER_OP(_np_squeeze) .set_attr("FCompute", UnaryOp::IdentityCompute); NNVM_REGISTER_OP(_npi_concatenate) -.set_attr("FCompute", ConcatCompute); +.set_attr("FCompute", NumpyConcatenateForward); NNVM_REGISTER_OP(_backward_np_concat) -.set_attr("FCompute", ConcatGradCompute); +.set_attr("FCompute", NumpyConcatenateBackward); NNVM_REGISTER_OP(_npi_stack) .set_attr("FCompute", StackOpForward); diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 15912dc47ad3..8416b1a9099f 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -83,6 +83,25 @@ def _add_workload_concatenate(array_pool): OpArgMngr.add_workload('concatenate', (a0.T, a1.T, a2.T), axis=0) out = np.empty(4, np.float32) OpArgMngr.add_workload('concatenate', (np.array([1, 2]), np.array([3, 4])), out=out) + OpArgMngr.add_workload('concatenate', [array_pool['4x1'], array_pool['4x1']], axis=None) + OpArgMngr.add_workload('concatenate', (np.arange(4).reshape((2, 2)), np.arange(4).reshape((2, 2))), axis=None) + OpArgMngr.add_workload('concatenate', (a23, a13), axis=None) + + +def _add_workload_append(): + def get_new_shape(shape, axis): + shape_lst = list(shape) + if axis is not None: + shape_lst[axis] = _np.random.randint(0, 3) + return tuple(shape_lst) + + for shape in [(0, 0), (2, 3), (2, 1, 3)]: + for axis in [0, 1, None]: + a = np.random.uniform(-1.0, 1.0, size=get_new_shape(shape, axis)) + b = np.random.uniform(-1.0, 1.0, size=get_new_shape(shape, axis)) + OpArgMngr.add_workload('append', a, b, axis=axis) + + OpArgMngr.add_workload('append', np.array([]), np.array([])) def _add_workload_copy(): @@ -1125,6 +1144,7 @@ def _prepare_workloads(): _add_workload_argmin() _add_workload_argmax() _add_workload_around() + _add_workload_append() _add_workload_broadcast_arrays(array_pool) _add_workload_broadcast_to() _add_workload_clip() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 605fa85e1f77..a2716fb5363f 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1908,43 +1908,112 @@ def hybrid_forward(self, F, a, *args): def get_new_shape(shape, axis): shape_lst = list(shape) - shape_lst[axis] = random.randint(0, 3) + if axis is not None: + shape_lst[axis] = random.randint(0, 3) return tuple(shape_lst) - for shape in [(0, 0), (2, 3)]: + for shape in [(0, 0), (2, 3), (2, 1, 3)]: for hybridize in [True, False]: - for axis in range(2): - # test gluon - test_concat = TestConcat(axis=axis) - if hybridize: - test_concat.hybridize() + for axis in [0, 1, None]: + for grad_req in ['write', 'add', 'null']: + # test gluon + test_concat = TestConcat(axis=axis) + if hybridize: + test_concat.hybridize() + + grad_req_c = grad_req + grad_req_d = grad_req + if grad_req == 'null': + ide = random.randint(0, 2) + grad_req_c = 'write' if ide == 0 else 'add' + grad_req_c = 'write' if ide == 1 else 'add' + + a = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() + a.attach_grad(grad_req) + b = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() + b.attach_grad(grad_req) + c = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() + c.attach_grad(grad_req_c) + d = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() + d.attach_grad(grad_req_d) + expected_ret = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis) - a = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() - a.attach_grad() - b = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() - b.attach_grad() - c = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() - c.attach_grad() - d = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() - d.attach_grad() - expected_ret = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis) - with mx.autograd.record(): - y = test_concat(a, b, c, d) + with mx.autograd.record(): + y = test_concat(a, b, c, d) + + assert y.shape == expected_ret.shape + assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5) + + y.backward() + if grad_req != 'null': + assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5) + if grad_req != 'null': + assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5) + if grad_req_c != 'null': + assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5) + if grad_req_d != 'null': + assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5) - assert y.shape == expected_ret.shape - assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5) + # test imperative + mx_out = np.concatenate([a, b, c, d], axis=axis) + np_out = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) - y.backward() - assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5) - assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5) - assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5) - assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5) +@with_seed() +@use_np +def test_np_append(): + class TestAppend(HybridBlock): + def __init__(self, axis=None): + super(TestAppend, self).__init__() + self._axis = axis - # test imperative - mx_out = np.concatenate([a, b, c, d], axis=axis) - np_out = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis) - assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + def hybrid_forward(self, F, a, b): + return F.np.append(a, b, axis=self._axis) + + def get_new_shape(shape, axis): + shape_lst = list(shape) + if axis is not None: + shape_lst[axis] = random.randint(0, 3) + return tuple(shape_lst) + + for shape in [(0, 0), (2, 3), (2, 1, 3)]: + for hybridize in [True, False]: + for axis in [0, 1, None]: + for grad_req_a in ['write', 'add', 'null']: + if grad_req_a == 'null': + continue + #set grad_req + grad_req_b = grad_req_a + if grad_req_a == 'null': + ide = random.randint(0, 2) + grad_req_b = 'write' if ide == 0 else 'add' + + #test gluon + test_append = TestAppend(axis=axis) + if hybridize: + test_append.hybridize() + + a = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() + a.attach_grad(grad_req=grad_req_a) + b = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() + b.attach_grad(grad_req=grad_req_b) + expected_ret = _np.append(a.asnumpy(), b.asnumpy(), axis=axis) + + with mx.autograd.record(): + y = test_append(a, b) + + assert y.shape == expected_ret.shape + assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5) + y.backward() + + if grad_req_a != 'null': + assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5) + assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5) + #test imperative + mx_out = np.append(a, b, axis=axis) + np_out = _np.append(a.asnumpy(), b.asnumpy(), axis=axis) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) @with_seed()