diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py
index 0d0e3b64491b..d1d67be06b05 100644
--- a/python/mxnet/_numpy_op_doc.py
+++ b/python/mxnet/_numpy_op_doc.py
@@ -20,6 +20,107 @@
 """Doc placeholder for numpy ops with prefix _np."""
 
 
+def _np_all(a, axis=None, keepdims=False, out=None):
+    """
+    Test whether all array elements along a given axis evaluate to True.
+
+    Parameters
+    ----------
+    a : array_like
+        Input array or object that can be converted to an array.
+    axis : None or int or tuple of ints, optional
+        Axis or axes along which a logical AND reduction is performed.
+        The default (axis=None) is to perform a logical AND over
+        all the dimensions of the input array.
+    keepdims : bool, optional
+        If this is set to True, the axes which are reduced are left in
+        the result as dimensions with size one. With this option,
+        the result will broadcast correctly against the input array.
+    out : ndarray, optional
+        Alternate output array in which to place the result. It must have
+        the same shape as the expected output and its type is preserved.
+
+    Returns
+    -------
+    all : ndarray, bool
+        A new boolean or array is returned unless out is specified,
+        in which case a reference to out is returned.
+
+    Examples
+    --------
+    >>> np.all([[True, False], [True, True]])
+    False
+
+    >>> np.all([[True, False], [True, True]], axis=0)
+    array([ True, False])
+
+    >>> np.all([-1, 4, 5])
+    True
+
+    >>> np.all([1.0, np.nan])
+    True
+
+    >>> o = np.array(False)
+    >>> z = np.all([-1, 4, 5], out=o)
+    >>> id(z), id(o), z
+    (28293632, 28293632, array(True))  # may vary
+    """
+    pass
+
+
+def _np_any(a, axis=None, keepdims=False, out=None):
+    """
+    Test whether any array element along a given axis evaluates to True.
+    Returns a single boolean unless axis is not None.
+
+    Parameters
+    ----------
+    a : array_like
+        Input array or object that can be converted to an array.
+    axis : None or int or tuple of ints, optional
+        Axis or axes along which a logical OR reduction is performed.
+        The default (axis=None) is to perform a logical OR over
+        all the dimensions of the input array.
+    keepdims : bool, optional
+        If this is set to True, the axes which are reduced are left in
+        the result as dimensions with size one. With this option,
+        the result will broadcast correctly against the input array.
+    out : ndarray, optional
+        Alternate output array in which to place the result. It must have
+        the same shape as the expected output and its type is preserved.
+
+    Returns
+    -------
+    any : bool or ndarray
+        A new boolean or ndarray is returned unless out is specified,
+        in which case a reference to out is returned.
+
+    Examples
+    --------
+    >>> np.any([[True, False], [True, True]])
+    True
+
+    >>> np.any([[True, False], [False, False]], axis=0)
+    array([ True, False])
+
+    >>> np.any([-1, 0, 5])
+    True
+
+    >>> np.any(np.nan)
+    True
+
+    >>> o = np.array(False)
+    >>> z = np.any([-1, 4, 5], out=o)
+    >>> z, o
+    (array(True), array(True))
+    >>> # Check now that z is a reference to o
+    >>> z is o
+    True
+    >>> id(z), id(o)  # identity of z and o  # doctest: +SKIP
+    (191614240, 191614240)
+    """
+    pass
+
+
 def _np_cumsum(a, axis=None, dtype=None, out=None):
     """
     Return the cumulative sum of the elements along a given axis.
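The docstrings above mirror the NumPy reference documentation. As a quick illustration of the semantics they describe (a minimal sketch; `np` here is `mxnet.numpy` from a build that includes this patch):

```python
from mxnet import np

x = np.array([[1, 0], [3, 4]])

# Reduce over every element vs. along one axis.
print(np.all(x))                         # -> False, a zero is present
print(np.all(x, axis=0))                 # -> [ True, False]
print(np.any(x, axis=1, keepdims=True))  # shape (2, 1), broadcastable against x

# NaN is truthy, exactly as in NumPy.
print(np.all(np.array([1.0, float('nan')])))  # -> True
```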
diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py
index bd5c388a5100..8594837a46b4 100644
--- a/python/mxnet/numpy_dispatch_protocol.py
+++ b/python/mxnet/numpy_dispatch_protocol.py
@@ -83,6 +83,8 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
 
 
 _NUMPY_ARRAY_FUNCTION_LIST = [
+    'all',
+    'any',
     'argmin',
     'argmax',
     'around',
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index cf35e8858039..5c92b78d9985 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -1089,6 +1089,14 @@ struct minimum : public mxnet_op::tunable {
   }
 };
 
+/*! \brief boolean any/all kernel that determines whether elem is NonZero */
+struct NonZero {
+  template<typename DType>
+  MSHADOW_XINLINE static bool Map(DType a) {
+    return (a != DType(0));
+  }
+};
+
 /*! \brief sum reducer that ignores NaN values in the input */
 struct nansum {
   /*! \brief do reduction into dst */
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index b15117f9f83b..d7752c4759db 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -1148,6 +1148,29 @@ struct set_to_int : public tunable {
  */
 using set_zero = set_to_int<0>;
 using set_one = set_to_int<1>;
+
+/*!
+ * \brief Set to immediate scalar value kernel
+ * \tparam val Scalar immediate
+ */
+template<bool val>
+struct set_to_bool : public tunable {
+  // mxnet_op version (when used directly with Kernel<>::Launch())
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(index_t i, DType *out) {
+    out[i] = DType(val);
+  }
+  // mshadow_op version (when used with op_with_req<>)
+  MSHADOW_XINLINE static int Map() {
+    return val;
+  }
+};
+
+/*!
+ * \brief Special-case kernel shortcuts for setting to true and false
+ */
+using set_true = set_to_bool<true>;
+using set_false = set_to_bool<false>;
 }  // namespace mxnet_op
 }  // namespace op
diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h
index 7d0025a62ad2..0efe2c2aa3df 100644
--- a/src/operator/numpy/np_broadcast_reduce_op.h
+++ b/src/operator/numpy/np_broadcast_reduce_op.h
@@ -86,6 +86,21 @@ struct NumpyReduceAxesNoDTypeParam : public dmlc::Parameter<NumpyReduceAxesNoDTypeParam> {
+struct NumpyReduceAxesBoolParam : public dmlc::Parameter<NumpyReduceAxesBoolParam> {
+  dmlc::optional<mxnet::Tuple<int>> axis;
+  bool keepdims;
+  DMLC_DECLARE_PARAMETER(NumpyReduceAxesBoolParam) {
+    DMLC_DECLARE_FIELD(axis)
+      .set_default(dmlc::optional<mxnet::Tuple<int>>())
+      .describe("Axis or axes along which the reduction is performed. The default, axis=None, "
+                "will reduce over all of the elements of the input array. If axis is negative "
+                "it counts from the last to the first axis.");
+    DMLC_DECLARE_FIELD(keepdims).set_default(false)
+      .describe("If this is set to `True`, the reduced axes are left "
+                "in the result as dimension with size one.");
+  }
+};
+
 inline TShape NumpyReduceAxesShapeImpl(const TShape& ishape,
                                        const dmlc::optional<mxnet::Tuple<int>>& axis,
                                        bool keepdims) {
@@ -173,6 +188,20 @@ inline bool NumpyReduceAxesShape(const nnvm::NodeAttrs& attrs,
   return shape_is_known(out_attrs->at(0));
 }
 
+inline bool NumpyReduceAxesBoolShape(const nnvm::NodeAttrs& attrs,
+                                     std::vector<mxnet::TShape> *in_attrs,
+                                     std::vector<mxnet::TShape> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  if (!shape_is_known(in_attrs->at(0))) {
+    return false;
+  }
+  const NumpyReduceAxesBoolParam& param = nnvm::get<NumpyReduceAxesBoolParam>(attrs.parsed);
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0,
+                     NumpyReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims));
+  return shape_is_known(out_attrs->at(0));
+}
+
 inline bool NumpyReduceAxesNoDTypeShape(const nnvm::NodeAttrs& attrs,
                                         std::vector<mxnet::TShape> *in_attrs,
                                         std::vector<mxnet::TShape> *out_attrs) {
@@ -298,6 +327,30 @@ void NumpyReduceAxesNoDTypeCompute(const nnvm::NodeAttrs& attrs,
   ReduceAxesComputeImpl<xpu, reducer, false, OP>(ctx, inputs, req, outputs, small);
 }
 
+template<typename xpu, typename reducer, bool init, typename OP = op::mshadow_op::identity>
+void NumpyReduceAxesBoolCompute(const nnvm::NodeAttrs& attrs,
+                                const OpContext& ctx,
+                                const std::vector<TBlob>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<TBlob>& outputs) {
+  const NumpyReduceAxesBoolParam& param = nnvm::get<NumpyReduceAxesBoolParam>(attrs.parsed);
+  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
+  if (inputs[0].shape_.Size() == 0 && outputs[0].shape_.Size() != 0) {
+    using namespace mxnet_op;
+    Kernel<set_to_bool<init>, xpu>::Launch(s, outputs[0].shape_.Size(), outputs[0].dptr<bool>());
+    return;
+  }
+  if (param.axis.has_value() && param.axis.value().ndim() == 0) {
+    UnaryOp::IdentityCompute<xpu>(attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  TShape small;
+  if (param.keepdims) {
+    small = outputs[0].shape_;
+  } else {
+    small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true);
+  }
+  ReduceAxesComputeBoolImpl<xpu, reducer, false, OP>(ctx, inputs, req, outputs, small);
+}
 
 template<typename xpu, bool normalize = false>
 inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
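The compute path above turns both operators into ordinary reductions over a truth-tested input: every element is first mapped through `NonZero`, then folded with a reducer. A minimal NumPy sketch of that decomposition (the sum/product reducer pairing is this sketch's assumption, not a quote of the kernel code):

```python
import numpy as np

def bool_reduce(a, reducer, axis=None, keepdims=False):
    """Emulate the NonZero-then-reduce pipeline on the host."""
    truth = (np.asarray(a) != 0)   # what mshadow_op::NonZero computes per element
    return reducer(truth, axis=axis, keepdims=keepdims).astype(bool)

x = np.array([[1, 0], [3, 4]])
# any == "does a nonzero survive a sum?", all == "does a product stay nonzero?"
assert bool_reduce(x, np.sum, axis=0).tolist() == np.any(x, axis=0).tolist()
assert bool_reduce(x, np.prod, axis=0).tolist() == np.all(x, axis=0).tolist()
```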
diff --git a/src/operator/numpy/np_broadcast_reduce_op_boolean.cc b/src/operator/numpy/np_broadcast_reduce_op_boolean.cc
new file mode 100644
index 000000000000..7529c0d4e1d3
--- /dev/null
+++ b/src/operator/numpy/np_broadcast_reduce_op_boolean.cc
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_broadcast_reduce_op_boolean.cc
+ * \brief CPU Implementation of broadcast and reduce functions based on boolean.
+ */
+
+#include "./np_broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+
+inline bool NumpyReduceAxesBoolType(const nnvm::NodeAttrs& attrs,
+                                    std::vector<int> *in_attrs,
+                                    std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool);
+  return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
+}
+
+DMLC_REGISTER_PARAMETER(NumpyReduceAxesBoolParam);
+
+NNVM_REGISTER_OP(_np_any)
+.set_attr_parser(ParamParser<NumpyReduceAxesBoolParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data"};
+  })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxesBoolShape)
+.set_attr<nnvm::FInferType>("FInferType", NumpyReduceAxesBoolType)
+.set_attr<FCompute>("FCompute", NumpyReduceAxesBoolCompute<cpu, mshadow_op::sum, false, mshadow_op::NonZero>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
+.add_arguments(NumpyReduceAxesBoolParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_np_all)
+.set_attr_parser(ParamParser<NumpyReduceAxesBoolParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data"};
+  })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxesBoolShape)
+.set_attr<nnvm::FInferType>("FInferType", NumpyReduceAxesBoolType)
+.set_attr<FCompute>("FCompute", NumpyReduceAxesBoolCompute<cpu, mshadow_op::product, true, mshadow_op::NonZero>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
+.add_arguments(NumpyReduceAxesBoolParam::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/numpy/np_broadcast_reduce_op_boolean.cu b/src/operator/numpy/np_broadcast_reduce_op_boolean.cu
new file mode 100644
index 000000000000..2c206bf88b2f
--- /dev/null
+++ b/src/operator/numpy/np_broadcast_reduce_op_boolean.cu
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_broadcast_reduce_op_boolean.cu
+ * \brief GPU Implementation of broadcast and reduce functions based on boolean.
+ */
+
+#include "./np_broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_np_any)
+.set_attr<FCompute>("FCompute", NumpyReduceAxesBoolCompute<gpu, mshadow_op::sum, false, mshadow_op::NonZero>);
+
+NNVM_REGISTER_OP(_np_all)
+.set_attr<FCompute>("FCompute", NumpyReduceAxesBoolCompute<gpu, mshadow_op::product, true, mshadow_op::NonZero>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index db898f8840f0..42971403f2ee 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -403,6 +403,8 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient);  // NO
 IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::lcm);  // NOLINT()
 IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<0>);  // NOLINT()
 IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<1>);  // NOLINT()
+IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_bool<true>);  // NOLINT()
+IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_bool<false>);  // NOLINT()
 IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ldexp);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rldexp);  // NOLINT()
diff --git a/src/operator/tensor/broadcast_reduce-inl.cuh b/src/operator/tensor/broadcast_reduce-inl.cuh
index 41940e4b1e07..6cd7dd50657a 100644
--- a/src/operator/tensor/broadcast_reduce-inl.cuh
+++ b/src/operator/tensor/broadcast_reduce-inl.cuh
@@ -634,6 +634,16 @@ void Reduce(Stream<gpu> *s, const TBlob& small, const OpReqType req,
   }
 }
 
+template<typename Reducer, int ndim, typename DType, typename OP>
+void ReduceBool(Stream<gpu> *s, const TBlob& small, const OpReqType req,
+                const Tensor<gpu, 1, char>& workspace, const TBlob& big) {
+  if (req == kNullOp) return;
+  cudaStream_t stream = Stream<gpu>::GetStream(s);
+  ReduceImplConfig<ndim> config =
+      ConfigureReduceImpl<ndim, DType>(small.shape_, big.shape_, NULL, NULL);
+  ReduceImpl<Reducer, ndim, bool, DType, bool, OP>(stream, small, req, big, workspace, config);
+}
+
 template<typename Reducer, int ndim, typename DType, typename OP>
 void ReduceWithExtraMem(Stream<gpu>* s, const TBlob& small, const OpReqType req,
                         const Tensor<gpu, 1, char>& workspace, const TBlob& big) {};
diff --git a/src/operator/tensor/broadcast_reduce-inl.h b/src/operator/tensor/broadcast_reduce-inl.h
index 0a20e1263fbf..841fbcd28a68 100644
--- a/src/operator/tensor/broadcast_reduce-inl.h
+++ b/src/operator/tensor/broadcast_reduce-inl.h
@@ -255,6 +255,18 @@ void Reduce(Stream<cpu>* s, const TBlob& small, const OpReqType req,
   }
 }
 
+template<typename Reducer, int ndim, typename DType, typename OP>
+void ReduceBool(Stream<cpu>* s, const TBlob& small, const OpReqType req,
+                const Tensor<cpu, 1, char>& workspace, const TBlob& big) {
+  if (req == kNullOp) return;
+  Shape<ndim> rshape, rstride;
+  diff(small.shape_.get<ndim>(), big.shape_.get<ndim>(), &rshape, &rstride);
+  size_t N = small.shape_.Size(), M = rshape.Size();
+  seq_reduce_compute<Reducer, ndim, bool, DType, bool, OP>(
+      N, M, req == kAddTo, big.dptr<DType>(), small.dptr<bool>(),
+      big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride);
+}
+
 template<typename Reducer, int ndim, typename DType, typename OP>
 void ReduceWithExtraMem(Stream<cpu>* s, const TBlob& small, const OpReqType req,
                         const Tensor<cpu, 1, char>& workspace, const TBlob& big) {
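The `set_true`/`set_false` kernels registered above cover the degenerate case of reducing an empty input into a non-empty output, where NumPy defines the result by the reduction identity. A quick NumPy check of the convention those kernels encode (shapes chosen purely for illustration):

```python
import numpy as np

empty = np.zeros((0, 3))
# Reducing over the empty axis: `all` falls back to the AND identity (True),
# `any` to the OR identity (False), so the output must be pre-filled.
print(np.all(empty, axis=0))  # array([ True,  True,  True])
print(np.any(empty, axis=0))  # array([False, False, False])
```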
diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index 27e22491ca35..c400a7cb2170 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -636,6 +636,39 @@ void ReduceAxesComputeImpl(const OpContext& ctx,
   });
 }
 
+template<typename xpu, typename reducer, bool normalize = false,
+         typename OP = op::mshadow_op::identity>
+void ReduceAxesComputeBoolImpl(const OpContext& ctx,
+                               const std::vector<TBlob>& inputs,
+                               const std::vector<OpReqType>& req,
+                               const std::vector<TBlob>& outputs,
+                               const mxnet::TShape& small) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+
+  mxnet::TShape src_shape, dst_shape;
+  BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
+      const TBlob in_data = inputs[0].reshape(src_shape);
+      const TBlob out_data = outputs[0].reshape(dst_shape);
+      BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
+        size_t workspace_size = broadcast::ReduceWorkspaceSize<NDim, OType>(
+            s, out_data.shape_, req[0], in_data.shape_);
+        Tensor<xpu, 1, char> workspace =
+            ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+        broadcast::ReduceBool<reducer, NDim, DType, OP>(
+            s, out_data, req[0], workspace, in_data);
+        if (normalize) {
+          auto out = out_data.FlatTo2D<xpu, OType>(s);
+          out /= scalar<OType>(src_shape.Size()/dst_shape.Size());
+        }
+      });
+    });
+  });
+}
+
 template<typename xpu, typename reducer, bool normalize = false,
          typename OP = op::mshadow_op::identity>
 void ReduceAxesCompute(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 0c501808a6c0..cd0bd8deeac3 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -1244,7 +1244,7 @@ void SliceAssignScalarOpForward(const nnvm::NodeAttrs& attrs,
     const int b = begin[i], e = end[i], s = step[i];
     SetSliceOpOutputDimSize(data.shape_, i, b, e, s, &vshape);
   }
-  MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(out.type_flag_, DType, {
     mxnet_op::Kernel<slice_assign_scalar<ndim>, xpu>::Launch(s, vshape.FlatTo2D()[0],
         out.dptr<DType>(), static_cast<DType>(param.scalar), req[0],
         out.shape_.get<ndim>(), vshape.get<ndim>(), begin, step);
diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py
index 5b5af8b20e36..386d634783f6 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -59,6 +59,31 @@ def get_workloads(name):
     return OpArgMngr._args.get(name, None)
 
 
+def _add_workload_all():
+    # check bad element in all positions
+    for i in range(256 - 7):
+        e = np.array([True] * 256, dtype=bool)[7::]
+        e[i] = False
+        OpArgMngr.add_workload('all', e)
+    # big array test for blocked libc loops
+    for i in list(range(9, 6000, 507)) + [7764, 90021, -10]:
+        e = np.array([True] * 100043, dtype=bool)
+        e[i] = False
+        OpArgMngr.add_workload('all', e)
+
+
+def _add_workload_any():
+    # check bad element in all positions
+    for i in range(256 - 7):
+        d = np.array([False] * 256, dtype=bool)[7::]
+        d[i] = True
+        OpArgMngr.add_workload('any', d)
+    # big array test for blocked libc loops
+    for i in list(range(9, 6000, 507)) + [7764, 90021, -10]:
+        d = np.array([False] * 100043, dtype=bool)
+        d[i] = True
+        OpArgMngr.add_workload('any', d)
+
+
 def _add_workload_unravel_index():
     OpArgMngr.add_workload('unravel_index', indices=np.array([2], dtype=_np.int64), shape=(2, 2))
     OpArgMngr.add_workload('unravel_index', np.array([(2*3 + 1)*6 + 4], dtype=_np.int64), (4, 3, 6))
@@ -1421,6 +1446,8 @@ def _prepare_workloads():
         '1x1x0': np.array([[[]]])
     }
 
+    _add_workload_all()
+    _add_workload_any()
     _add_workload_argmin()
     _add_workload_argmax()
     _add_workload_around()
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index b39703b8ebda..42ba3df359f7 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -493,6 +493,84 @@ def is_int(dtype):
         assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
 
 
+@with_seed()
+@use_np
+def test_np_any():
+    class TestAny(HybridBlock):
+        def __init__(self, axis=None, keepdims=False):
+            super(TestAny, self).__init__()
+            self._axis = axis
+            self._keepdims = keepdims
+
+        def hybrid_forward(self, F, a):
+            return F.np.any(a, axis=self._axis, keepdims=self._keepdims)
+
+    keepdims = [True, False]
+    axes = [True, False]
+    shapes = [(), (5, ), (10, ),
+              (2, 5), (5, 5), (10, 10),
+              (4, 4, 4), (4, 6, 9), (6, 6, 6),
+              (7, 8, 9, 10), (7, 9, 11, 13)]
+    dtypes = [np.int8, np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64, np.bool]
+
+    combinations = itertools.product([False, True], shapes, dtypes, axes, keepdims)
+    for hybridize, shape, dtype, axis, keepdim in combinations:
+        ndim = len(shape)
+        samples = random.randint(0, ndim)
+        axis = None if not axis else tuple(random.sample([i for i in range(0, ndim)], samples))
+        x = np.random.normal(0, 1.0, size=shape).astype(dtype)
+        test_any = TestAny(axis=axis, keepdims=keepdim)
+        if hybridize:
+            test_any.hybridize()
+        y = test_any(x)
+        expected_ret = _np.any(x.asnumpy(), axis=axis, keepdims=keepdim)
+        assert_almost_equal(y.asnumpy(), expected_ret)
+
+        # test imperative
+        mx_outs = np.any(x, axis=axis, keepdims=keepdim)
+        np_outs = _np.any(x.asnumpy(), axis=axis, keepdims=keepdim)
+        assert_almost_equal(mx_outs.asnumpy(), np_outs)
+
+
+@with_seed()
+@use_np
+def test_np_all():
+    class TestAll(HybridBlock):
+        def __init__(self, axis=None, keepdims=False):
+            super(TestAll, self).__init__()
+            self._axis = axis
+            self._keepdims = keepdims
+
+        def hybrid_forward(self, F, a):
+            return F.np.all(a, axis=self._axis, keepdims=self._keepdims)
+
+    keepdims = [True, False]
+    axes = [True, False]
+    shapes = [(), (5, ), (10, ),
+              (2, 5), (5, 5), (10, 10),
+              (4, 4, 4), (4, 6, 9), (6, 6, 6),
+              (7, 8, 9, 10), (7, 9, 11, 13)]
+    dtypes = [np.int8, np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64, np.bool]
+
+    combinations = itertools.product([False, True], shapes, dtypes, axes, keepdims)
+    for hybridize, shape, dtype, axis, keepdim in combinations:
+        ndim = len(shape)
+        samples = random.randint(0, ndim)
+        axis = None if not axis else tuple(random.sample([i for i in range(0, ndim)], samples))
+        x = np.random.normal(0, 1.0, size=shape).astype(dtype)
+        test_all = TestAll(axis=axis, keepdims=keepdim)
+        if hybridize:
+            test_all.hybridize()
+        y = test_all(x)
+        expected_ret = _np.all(x.asnumpy(), axis=axis, keepdims=keepdim)
+        assert_almost_equal(y.asnumpy(), expected_ret)
+
+        # test imperative
+        mx_outs = np.all(x, axis=axis, keepdims=keepdim)
+        np_outs = _np.all(x.asnumpy(), axis=axis, keepdims=keepdim)
+        assert_almost_equal(mx_outs.asnumpy(), np_outs)
+
+
 @with_seed()
 @use_np
 def test_np_max_min():
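The `_NUMPY_ARRAY_FUNCTION_LIST` entries added earlier mean these two operators also participate in NumPy's `__array_function__` protocol, which the interoperability workloads above exercise. A minimal sketch of what that enables (assumes an MXNet build containing this patch and NumPy >= 1.17):

```python
import numpy as onp            # official NumPy
from mxnet import np, npx
npx.set_np()

x = np.array([[True, False], [True, True]])
# Calling *official* NumPy on an MXNet ndarray dispatches back to MXNet,
# so the result stays an mxnet.numpy.ndarray.
out = onp.all(x, axis=0)
print(type(out), out)          # <class 'mxnet.numpy.ndarray'> [ True False]
```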