diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 9a62620da85c..fb329f1865a9 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -47,7 +47,7 @@ "imdecode", "lesser", "lesser_equal", "logical_and", "logical_or", "logical_xor", "maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal", "onehot_encode", "power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram", - "to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack"] + "split_v2", "to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack"] _STORAGE_TYPE_UNDEFINED = -1 _STORAGE_TYPE_DEFAULT = 0 @@ -1133,6 +1133,14 @@ def split(self, *args, **kwargs): """ return op.split(self, *args, **kwargs) + def split_v2(self, *args, **kwargs): + """Convenience fluent method for :py:func:`split_v2`. + + The arguments are the same as for :py:func:`split_v2`, with + this array as data. + """ + return split_v2(self, *args, **kwargs) + def slice(self, *args, **kwargs): """Convenience fluent method for :py:func:`slice`. @@ -3901,6 +3909,12 @@ def histogram(a, bins=10, range=None): Values outside the range are ignored. The first element of the range must be less than or equal to the second. range affects the automatic bin computation as well, the range will be equally divided by the number of bins. + + Returns + ------- + NDArray + A created array. + """ # pylint: disable= no-member, protected-access @@ -3916,6 +3930,51 @@ def histogram(a, bins=10, range=None): raise ValueError("bins argument should be either an integer or an NDArray") # pylint: enable= no-member, protected-access, redefined-builtin +def split_v2(ary, indices_or_sections, axis=0, squeeze_axis=False): + """Split an array into multiple sub-arrays. + + Parameters + ---------- + ary : NDArray + Array to be divided into sub-arrays. + indices_or_sections : int or tuple of ints + If `indices_or_sections` is an integer, N, the array will be divided + into N equal arrays along `axis`. If such a split is not possible, + an error is raised. + If `indices_or_sections` is a 1-D array of sorted integers, the entries + indicate where along `axis` the array is split. For example, + ``[2, 3]`` would, for ``axis=0``, result in + - ary[:2] + - ary[2:3] + - ary[3:] + If an index exceeds the dimension of the array along `axis`, + an empty sub-array is returned correspondingly. + axis : int, optional + The axis along which to split, default is 0. + squeeze_axis: boolean, optional + Whether to squeeze the axis of sub-arrays or not, only useful when size + of the sub-arrays are 1 on the `axis`. Default is False. + + Returns + ------- + NDArray + A created array. + + """ + indices = [] + axis_size = ary.shape[axis] + if isinstance(indices_or_sections, int): + sections = indices_or_sections + if axis_size % sections: + raise ValueError('array split does not result in an equal division') + section_size = int(axis_size / sections) + indices = [i * section_size for i in range(sections)] + elif isinstance(indices_or_sections, tuple): + indices = [0] + list(indices_or_sections) + else: + raise ValueError('indices_or_sections must either int or tuple of ints') + return _internal._split_v2(ary, indices, axis, squeeze_axis) + PyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p) _c_str_dltensor = c_str('dltensor') _c_str_used_dltensor = c_str('used_dltensor') diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 530d72796c00..43de0c9d7535 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -48,7 +48,7 @@ __all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json", "pow", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange", - "histogram"] + "histogram", "split_v2"] class Symbol(SymbolBase): @@ -1855,6 +1855,14 @@ def split(self, *args, **kwargs): """ return op.split(self, *args, **kwargs) + def split_v2(self, *args, **kwargs): + """Convenience fluent method for :py:func:`split_v2`. + + The arguments are the same as for :py:func:`split_v2`, with + this array as data. + """ + return split_v2(self, *args, **kwargs) + def slice(self, *args, **kwargs): """Convenience fluent method for :py:func:`slice`. @@ -2958,6 +2966,11 @@ def histogram(a, bins=10, range=None, **kwargs): Values outside the range are ignored. The first element of the range must be less than or equal to the second. range affects the automatic bin computation as well, the range will be equally divided by the number of bins. + + Returns + ------- + out : Symbol + The created Symbol """ if isinstance(bins, Symbol): return _internal._histogram(data=a, bins=bins, **kwargs) @@ -2967,4 +2980,44 @@ def histogram(a, bins=10, range=None, **kwargs): return _internal._histogram(data=a, bin_cnt=bins, range=range, **kwargs) raise ValueError("bins argument should be either an integer or an NDArray") +def split_v2(ary, indices_or_sections, axis=0, squeeze_axis=False): + """Split an array into multiple sub-arrays. + + Parameters + ---------- + ary : NDArray + Array to be divided into sub-arrays. + indices_or_sections : int or tuple of ints + If `indices_or_sections` is an integer, N, the array will be divided + into N equal arrays along `axis`. If such a split is not possible, + an error is raised. + If `indices_or_sections` is a 1-D array of sorted integers, the entries + indicate where along `axis` the array is split. For example, + ``[2, 3]`` would, for ``axis=0``, result in + - ary[:2] + - ary[2:3] + - ary[3:] + If an index exceeds the dimension of the array along `axis`, + an empty sub-array is returned correspondingly. + axis : int, optional + The axis along which to split, default is 0. + squeeze_axis: boolean, optional + Whether to squeeze the axis of sub-arrays or not, only useful when size + of the sub-arrays are 1 on the `axis`. Default is False. + + Returns + ------- + out : Symbol + The created Symbol + """ + indices = [] + sections = 0 + if isinstance(indices_or_sections, int): + sections = indices_or_sections + elif isinstance(indices_or_sections, tuple): + indices = [0] + list(indices_or_sections) + else: + raise ValueError('indices_or_sections must either int or tuple of ints') + return _internal._split_v2(ary, indices, axis, squeeze_axis, sections) + _set_symbol_class(Symbol) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 8b575ca75365..97c4fa55681c 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -2527,6 +2527,319 @@ void SpaceToDepthOpForward(const nnvm::NodeAttrs& attrs, }); } +namespace split_enum { +enum SplitOpInputs {kData}; +} // namespace split_enum + +struct SplitParam : public dmlc::Parameter { + TShape indices; + int axis; + bool squeeze_axis; + int sections; + DMLC_DECLARE_PARAMETER(SplitParam) { + DMLC_DECLARE_FIELD(indices) + .describe("Indices of splits. The elements should denote the boundaries of at which split" + " is performed along the `axis`."); + DMLC_DECLARE_FIELD(axis).set_default(1) + .describe("Axis along which to split."); + DMLC_DECLARE_FIELD(squeeze_axis).set_default(0) + .describe("If true, Removes the axis with length 1 from the shapes of the output arrays." + " **Note** that setting `squeeze_axis` to ``true`` removes axis with length 1" + " only along the `axis` which it is split." + " Also `squeeze_axis` can be set to ``true``" + " only if ``input.shape[axis] == num_outputs``."); + DMLC_DECLARE_FIELD(sections).set_default(0) + .describe("Number of sections if equally splitted. Default to 0 which means split by indices."); + } +}; // struct SplitParam + +inline TShape GetSplitIndices(const TShape& ishape, int axis, int sections) { + TShape indices(sections+1); + indices[0] = 0; + int64_t section_size = ishape[axis] / sections; + for (int i = 0; i < sections; ++i) { + indices[i+1] = section_size * (i + 1); + } + return indices; +} + +inline bool SplitOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + int dtype = (*in_attrs)[0]; + CHECK_NE(dtype, -1) << "First input must have specified type"; + const SplitParam& param = nnvm::get(attrs.parsed); + out_attrs->clear(); + int num_outputs = (param.sections > 0) ? param.sections : param.indices.ndim(); + for (int i = 0; i < num_outputs; ++i) { + out_attrs->push_back(dtype); + } + return true; +} + +inline bool SplitOpShape(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + using namespace mshadow; + const SplitParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), 1U); + TShape dshape = in_attrs->at(split_enum::kData); + TShape ishape = in_attrs->at(split_enum::kData); + if (dshape.ndim() == 0) return false; + if (param.axis >= 0) { + CHECK_LT(static_cast(param.axis), dshape.ndim()); + } else { + CHECK_LT(param.axis + dshape.ndim(), dshape.ndim()); + } + int real_axis = param.axis; + if (real_axis < 0) { + real_axis += dshape.ndim(); + } + const TShape indices = + (param.sections > 0) ? GetSplitIndices(ishape, real_axis, param.sections) : param.indices; + int num_outputs = (param.sections > 0) ? indices.ndim() - 1 : indices.ndim(); + // Pre-compute squeezed output shape for future usage + TShape squeezed_dshape = dshape; + for (int d = real_axis; d < static_cast(squeezed_dshape.ndim()) - 1; ++d) { + squeezed_dshape[d] = squeezed_dshape[d+1]; + } + squeezed_dshape = TShape(&squeezed_dshape[0], &squeezed_dshape[squeezed_dshape.ndim()-1]); + // Assign shape to every output + for (int i = 0; i < num_outputs; ++i) { + int start = indices[i]; + int end = (i < num_outputs - 1) ? indices[i + 1] : ishape[real_axis]; + CHECK(start < end) + << "start " << start << " is not less than end " << end << "for subarray " << i; + CHECK(end <= ishape[real_axis]) + << "end " << end << " is no less than the size of the axis " << ishape[real_axis]; + dshape[real_axis] = (end - start); + if (param.squeeze_axis) { + CHECK_EQ(end - start, 1U) << "expected axis size of 1 but got " << end - start; + SHAPE_ASSIGN_CHECK(*out_attrs, i, squeezed_dshape); + } else { + SHAPE_ASSIGN_CHECK(*out_attrs, i, dshape); + } + } + TShape back_calculate_dshape = ishape; + back_calculate_dshape[real_axis] = 0; + for (int d = 0; d < real_axis; ++d) { + back_calculate_dshape[d] = (*out_attrs)[0][d]; + } + if (param.squeeze_axis) { + back_calculate_dshape[real_axis] = num_outputs; + } else { + for (int i = 0; i < num_outputs; ++i) { + back_calculate_dshape[real_axis] += (*out_attrs)[i][real_axis]; + } + } + for (int d = real_axis + 1; d < static_cast(ishape.ndim()); ++d) { + if (param.squeeze_axis) { + back_calculate_dshape[d] = (*out_attrs)[0][d - 1]; + } else { + back_calculate_dshape[d] = (*out_attrs)[0][d]; + } + } + SHAPE_ASSIGN_CHECK(*in_attrs, split_enum::kData, back_calculate_dshape); + return true; +} + +struct SplitKernel { + /*! + * \brief Map function for forward split_v2 operator + * \param i global thread id + * \param in_data ptr to input buffer + * \param out_data ptr to ptr of outputs buffer + * \param indices ptr to indices buffer + * \param num_sections # of sections after split + * \param axis_size size of axis to be splitted on + * \param trailing_size step size within the data buffer of the axis to be splitted on + */ + template + static MSHADOW_XINLINE void Map(size_t i, + const DType *in_data, DType** out_data, const size_t* indices, + const size_t num_sections, const size_t axis_size, + const size_t trailing_size) { + size_t idx = i / trailing_size % axis_size; + size_t target = 0; + for (size_t section = 0; + section < num_sections && indices[section] <= idx; + target = section++) {} + DType* target_data = out_data[target]; + const size_t mid_idx = idx - indices[target]; + const size_t head_idx = i / (trailing_size * axis_size); + const size_t tail_idx = i % trailing_size; + const size_t section_size = indices[target + 1] - indices[target]; + const size_t target_idx = + head_idx * trailing_size * section_size + mid_idx * trailing_size + tail_idx; + target_data[target_idx] = in_data[i]; + } +}; + +struct ConcatenateKernel { + /*! + * \brief Map function for backward split_v2 operator + * \param i global thread id + * \param out_grad ptr to ptr of out grads buffer + * \param in_grad ptr to input grad buffer + * \param indices ptr to indices buffer + * \param num_sections # of sections after split + * \param axis_size size of axis to be splitted on + * \param trailing_size step size within the data buffer of the axis to be splitted on + */ + template + static MSHADOW_XINLINE void Map(size_t i, + DType** out_grad, DType* in_grad, const size_t* indices, + const size_t num_sections, const size_t axis_size, + const size_t trailing_size) { + size_t idx = i / trailing_size % axis_size; + size_t src = 0; + for (size_t section = 0; + section < num_sections && indices[section] <= idx; + src = section++) {} + DType* src_grad = out_grad[src]; + const size_t mid_idx = idx - indices[src]; + const size_t head_idx = i / (trailing_size * axis_size); + const size_t tail_idx = i % trailing_size; + const size_t section_size = indices[src + 1] - indices[src]; + const size_t src_idx = + head_idx * trailing_size * section_size + mid_idx * trailing_size + tail_idx; + in_grad[i] = src_grad[src_idx]; + } +}; + +template +inline void SplitOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + const SplitParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), (param.sections > 0) ? param.sections : param.indices.ndim()); + Stream *s = ctx.get_stream(); + const TBlob& input_data = inputs[split_enum::kData]; + size_t leading = 1, trailing = 1; + int real_axis = param.axis; + if (real_axis < 0) { + real_axis += input_data.ndim(); + } + CHECK_LT(real_axis, input_data.ndim()); + size_t mid = input_data.shape_[real_axis]; + for (int i = 0; i < real_axis; ++i) { + leading *= input_data.shape_[i]; + } + for (int i = real_axis + 1; i < input_data.ndim(); ++i) { + trailing *= input_data.shape_[i]; + } + + size_t workspace_size = 0; + const TShape& ishape = input_data.shape_; + const TShape split_pts = + (param.sections > 0) ? GetSplitIndices(ishape, real_axis, param.sections) : param.indices; + std::vector indices; + for (const auto& section : split_pts) { + indices.push_back(section); + } + if (param.sections == 0) { + indices.push_back(ishape[real_axis]); + } + workspace_size += indices.size() * sizeof(size_t); + MSHADOW_TYPE_SWITCH(input_data.type_flag_, DType, { + std::vector output_data; + for (const TBlob& data : outputs) { + output_data.push_back(data.dptr()); + } + workspace_size += output_data.size() * sizeof(DType*); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + Tensor indices_cpu_tensor(indices.data(), Shape1(indices.size())); + Tensor indices_xpu_tensor( + reinterpret_cast(workspace.dptr_), Shape1(indices.size())); + Tensor ptrs_cpu_tensor(output_data.data(), Shape1(output_data.size())); + Tensor ptrs_xpu_tensor( + reinterpret_cast(workspace.dptr_ + indices.size() * sizeof(size_t)), + Shape1(output_data.size())); + mshadow::Copy(indices_xpu_tensor, indices_cpu_tensor, s); + mshadow::Copy(ptrs_xpu_tensor, ptrs_cpu_tensor, s); + Kernel::Launch( + s, input_data.Size(), input_data.dptr(), ptrs_xpu_tensor.dptr_, + indices_xpu_tensor.dptr_, indices.size() - 1, mid, trailing); + }); +} + +template +inline void SplitOpBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + const SplitParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(inputs.size(), (param.sections > 0) ? param.sections : param.indices.ndim()) + << "out grad vector size mush match the output size"; + CHECK_EQ(outputs.size(), 1U); + Stream *s = ctx.get_stream(); + TBlob input_grad = outputs[split_enum::kData]; + size_t leading = 1, trailing = 1; + int real_axis = param.axis; + if (real_axis < 0) { + real_axis += input_grad.ndim(); + } + CHECK_LT(real_axis, input_grad.ndim()); + size_t mid = input_grad.shape_[real_axis]; + for (int i = 0; i < real_axis; ++i) { + leading *= input_grad.shape_[i]; + } + for (int i = real_axis + 1; i < input_grad.ndim(); ++i) { + trailing *= input_grad.shape_[i]; + } + + size_t workspace_size = 0; + const TShape& ishape = input_grad.shape_; + const TShape split_pts = + (param.sections > 0) ? GetSplitIndices(ishape, real_axis, param.sections) : param.indices; + std::vector indices; + for (const auto& section : split_pts) { + indices.push_back(section); + } + if (param.sections == 0) { + indices.push_back(ishape[real_axis]); + } + workspace_size += indices.size() * sizeof(size_t); + MSHADOW_TYPE_SWITCH(input_grad.type_flag_, DType, { + std::vector out_grads; + for (const TBlob& output_grad : inputs) { + out_grads.push_back(output_grad.dptr()); + } + workspace_size += out_grads.size() * sizeof(DType*); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + Tensor indices_cpu_tensor(indices.data(), Shape1(indices.size())); + Tensor indices_xpu_tensor( + reinterpret_cast(workspace.dptr_), Shape1(indices.size())); + Tensor ptrs_cpu_tensor(out_grads.data(), Shape1(inputs.size())); + Tensor ptrs_xpu_tensor( + reinterpret_cast(workspace.dptr_ + indices.size() * sizeof(size_t)), + Shape1(inputs.size())); + mshadow::Copy(indices_xpu_tensor, indices_cpu_tensor, s); + mshadow::Copy(ptrs_xpu_tensor, ptrs_cpu_tensor, s); + Kernel::Launch( + s, input_grad.Size(), ptrs_xpu_tensor.dptr_, input_grad.dptr(), + indices_xpu_tensor.dptr_, indices.size() - 1, mid, trailing); + }); +} + +inline uint32_t SplitNumOutputs(const NodeAttrs& attrs) { + const SplitParam& param = nnvm::get(attrs.parsed); + return (param.sections > 0) ? param.sections : param.indices.ndim(); +} + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index ed8912f7b7be..e5d354be629f 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -103,6 +103,7 @@ DMLC_REGISTER_PARAMETER(ReverseParam); DMLC_REGISTER_PARAMETER(StackParam); DMLC_REGISTER_PARAMETER(SqueezeParam); DMLC_REGISTER_PARAMETER(DepthToSpaceParam); +DMLC_REGISTER_PARAMETER(SplitParam); #if MXNET_USE_MKLDNN == 1 void MKLDNNReshape(const NDArray &in_data, const NDArray &out_data) { @@ -1071,8 +1072,8 @@ Example:: [12, 18, 13, 19, 14, 20], [3, 9, 4, 10, 5, 11], [15, 21, 16, 22, 17, 23]]]] - - + + space_to_depth(x, 2) = [[[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], @@ -1100,5 +1101,103 @@ Example:: .add_argument("data", "NDArray-or-Symbol", "Input ndarray") .add_arguments(DepthToSpaceParam::__FIELDS__()); +NNVM_REGISTER_OP(_split_v2) +.describe(R"code(Splits an array along a particular axis into multiple sub-arrays. + +Example:: + + x = [[[ 1.] + [ 2.]] + [[ 3.] + [ 4.]] + [[ 5.] + [ 6.]]] + x.shape = (3, 2, 1) + + y = split_v2(x, axis=1, indices_or_sections=2) // a list of 2 arrays with shape (3, 1, 1) + y = [[[ 1.]] + [[ 3.]] + [[ 5.]]] + + [[[ 2.]] + [[ 4.]] + [[ 6.]]] + + y[0].shape = (3, 1, 1) + + z = split_v2(x, axis=0, indices_or_sections=3) // a list of 3 arrays with shape (1, 2, 1) + z = [[[ 1.] + [ 2.]]] + + [[[ 3.] + [ 4.]]] + + [[[ 5.] + [ 6.]]] + + z[0].shape = (1, 2, 1) + + w = split_v2(x, axis=0, indices_or_sections=(1,)) // a list of 2 arrays with shape [(1, 2, 1), (2, 2, 1)] + w = [[[ 1.] + [ 2.]]] + + [[[3.] + [4.]] + + [[5.] + [6.]]] + + w[0].shape = (1, 2, 1) + w[1].shape = (2, 2, 1) + +`squeeze_axis=True` removes the axis with length 1 from the shapes of the output arrays. +**Note** that setting `squeeze_axis` to ``1`` removes axis with length 1 only +along the `axis` which it is split. +Also `squeeze_axis` can be set to true only if ``input.shape[axis] == indices_or_sections``. + +Example:: + + z = split_v2(x, axis=0, indices_or_sections=3, squeeze_axis=1) // a list of 3 arrays with shape (2, 1) + z = [[ 1.] + [ 2.]] + + [[ 3.] + [ 4.]] + + [[ 5.] + [ 6.]] + z[0].shape = (2, 1) + +)code" ADD_FILELINE) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(SplitNumOutputs) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) +.set_attr("FInferShape", SplitOpShape) +.set_attr("FInferType", SplitOpType) +.set_attr("FCompute", SplitOpForward) +.set_attr("FResourceRequest", + [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +.set_attr("FGradient", ElemwiseGradUseNone{"_split_v2_backward"}) +.add_argument("data", "NDArray-or-Symbol", "The input") +.add_arguments(SplitParam::__FIELDS__()); + +NNVM_REGISTER_OP(_split_v2_backward) +.set_attr_parser(ParamParser) +.set_num_inputs(SplitNumOutputs) +.set_num_outputs(1) +.set_attr("TIsBackward", true) +.set_attr("FResourceRequest", + [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +.set_attr("FCompute", SplitOpBackward); + + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu index 4e31a4cf1155..87311276da26 100644 --- a/src/operator/tensor/matrix_op.cu +++ b/src/operator/tensor/matrix_op.cu @@ -217,5 +217,11 @@ NNVM_REGISTER_OP(depth_to_space) NNVM_REGISTER_OP(space_to_depth) .set_attr("FCompute", SpaceToDepthOpForward); +NNVM_REGISTER_OP(_split_v2) +.set_attr("FCompute", SplitOpForward); + +NNVM_REGISTER_OP(_split_v2_backward) +.set_attr("FCompute", SplitOpBackward); + } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 0aa48553901b..7176b1888607 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -1056,7 +1056,7 @@ def test_output(): @with_seed() def test_ndarray_fluent(): has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose', 'sum', 'nansum', 'prod', - 'nanprod', 'mean', 'max', 'min', 'reshape', 'broadcast_to', 'split', + 'nanprod', 'mean', 'max', 'min', 'reshape', 'broadcast_to', 'split', 'split_v2', 'broadcast_axes', 'pad', 'swapaxes', 'slice', 'slice_axis', 'slice_like', 'take', 'one_hot', 'pick', 'sort', 'topk', 'argsort', 'argmax', 'argmin', 'clip', 'abs', 'sign', 'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan', @@ -1093,6 +1093,8 @@ def check_fluent_regular(func, kwargs, shape=(5, 17, 1), equal_nan=False): check_fluent_regular('repeat', {'repeats': 3}) check_fluent_regular('transpose', {'axes': (1,0,2)}) check_fluent_regular('split', {'axis': 2, 'num_outputs': 3}, shape=(5, 17, 6)) + check_fluent_regular('split_v2', {'axis': 2, 'indices_or_sections': 3}, shape=(5, 17, 6)) + check_fluent_regular('split_v2', {'axis': 2, 'indices_or_sections': (1, 3, 5)}, shape=(5, 17, 6)) check_fluent_regular('slice', {'begin': (2, 5, 1), 'end': (4, 7, 6)}, shape=(5, 17, 6)) check_fluent_regular('slice_axis', {'axis': 1, 'begin': 5, 'end': 7}) check_fluent_regular('slice_like', {'axes': (0, -2), 'shape_like': mx.nd.zeros((3, 3))}) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index cda801c25bbb..67aeddf19c44 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -7248,6 +7248,25 @@ def f_sm_ce(data, label): check_symbolic_forward(sym, {'data' : np_data, 'label' : np_label}, [np.array([f_sm_ce(np_sm, np_one_hot_label)])], rtol=1e-3, atol=1e-5) +@with_seed() +def test_split_v2(): + dim = random.randint(2, 6) + shape = rand_shape_nd(dim) + axis = random.randint(-dim, dim-1) + axis_size = shape[axis] + samples = random.randint(0, axis_size - 1) + indices = sorted(random.sample([i for i in range(1, axis_size)], samples)) + indices = tuple(indices) + mx_data = rand_ndarray(shape) + np_data = mx_data.asnumpy() + np_out = np.split(np_data, indices_or_sections=indices, axis=axis) + data = mx.sym.Variable("data") + sym = mx.sym.split_v2(data, indices_or_sections=indices, axis=axis) + check_symbolic_forward(sym, {"data": mx_data}, np_out, rtol=1e-3, atol=1e-5) + out_grad = [np.ones(arr.shape) for arr in np_out] + check_symbolic_backward(sym, {"data": mx_data}, out_grad, [np.concatenate(out_grad, axis=axis)]) + + @with_seed() def test_invalid_kernel_size(): invalid_kernel_size = 28