diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h
index b05f4e901d13..305f1c775cd5 100644
--- a/src/operator/numpy/np_broadcast_reduce_op.h
+++ b/src/operator/numpy/np_broadcast_reduce_op.h
@@ -64,6 +64,24 @@ struct NumpyReduceAxesParam : public dmlc::Parameter<NumpyReduceAxesParam> {
   }
 };
 
+struct NumpyReduceAxesNoDTypeParam : public dmlc::Parameter<NumpyReduceAxesNoDTypeParam> {
+  dmlc::optional<mxnet::Tuple<int>> axis;
+  bool keepdims;
+  dmlc::optional<double> initial;
+  DMLC_DECLARE_PARAMETER(NumpyReduceAxesNoDTypeParam) {
+    DMLC_DECLARE_FIELD(axis)
+      .set_default(dmlc::optional<mxnet::Tuple<int>>())
+      .describe("Axis or axes along which the reduction is performed. The default, axis=None, "
+                "reduces all of the elements of the input array. If axis is negative it counts "
+                "from the last to the first axis.");
+    DMLC_DECLARE_FIELD(keepdims).set_default(false)
+      .describe("If this is set to `True`, the reduced axes are left "
+                "in the result as dimensions with size one.");
+    DMLC_DECLARE_FIELD(initial).set_default(dmlc::optional<double>())
+      .describe("Starting value for the reduction.");
+  }
+};
+
 inline TShape NumpyReduceAxesShapeImpl(const TShape& ishape,
                                        const dmlc::optional<mxnet::Tuple<int>>& axis,
                                        bool keepdims) {
@@ -152,6 +170,39 @@ inline bool NumpyReduceAxesShape(const nnvm::NodeAttrs& attrs,
   return shape_is_known(out_attrs->at(0));
 }
 
+inline bool NumpyReduceAxesNoDTypeShape(const nnvm::NodeAttrs& attrs,
+                                        std::vector<TShape> *in_attrs,
+                                        std::vector<TShape> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  if (!shape_is_known(in_attrs->at(0))) {
+    return false;
+  }
+  const NumpyReduceAxesNoDTypeParam& param = nnvm::get<NumpyReduceAxesNoDTypeParam>(attrs.parsed);
+  // reject reductions over zero-size axes, which have no identity for max/min
+  bool is_all_reduced_axes_not_zero = true;
+  const TShape& ishape = (*in_attrs)[0];
+  if (param.axis.has_value()) {
+    const mxnet::Tuple<int>& axes = param.axis.value();
+    for (int i = 0; i < axes.ndim(); ++i) {
+      if (ishape[axes[i]] == 0) {
+        is_all_reduced_axes_not_zero = false;
+        break;
+      }
+    }
+  } else {
+    if (ishape.Size() == 0) {
+      // a global reduction is only valid when the input has more than zero elements
+      is_all_reduced_axes_not_zero = false;
+    }
+  }
+  CHECK(is_all_reduced_axes_not_zero)
+    << "zero-size array to reduction operation maximum which has no identity";
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0,
+                     NumpyReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims));
+  return shape_is_known(out_attrs->at(0));
+}
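+// Note: an input of shape (2, 0) fails the check above for axis=1 or
+// axis=None, while a 0-dim input of shape () has Size() == 1 and remains a
+// valid global reduction. The message mirrors NumPy's error for np.max on
+// empty arrays; it is shared by _np_min, which reuses this shape function.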
+
 template<bool safe_acc_hint = false>
 inline bool NeedSafeAcc(int itype, int otype) {
   bool rule = (itype != otype) || (itype != mshadow::kFloat32 && itype != mshadow::kFloat64);
@@ -186,6 +237,30 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template<typename xpu, typename reducer>
+void NumpyReduceAxesNoDTypeCompute(const nnvm::NodeAttrs& attrs,
+                                   const OpContext& ctx,
+                                   const std::vector<TBlob>& inputs,
+                                   const std::vector<OpReqType>& req,
+                                   const std::vector<TBlob>& outputs) {
+  const NumpyReduceAxesNoDTypeParam& param = nnvm::get<NumpyReduceAxesNoDTypeParam>(attrs.parsed);
+  if (param.initial.has_value()) {
+    LOG(FATAL) << "initial is not supported yet";
+  }
+  if (inputs[0].shape_.Size() == 0U || outputs[0].shape_.Size() == 0U) return;  // zero-size tensor
+  if (param.axis.has_value() && param.axis.value().ndim() == 0) {
+    // an empty axis tuple reduces over no axes, so the output is a copy of the input
+    UnaryOp::IdentityCompute<xpu>(attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  TShape small;
+  if (param.keepdims) {
+    small = outputs[0].shape_;
+  } else {
+    small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true);
+  }
+  ReduceAxesComputeImpl<xpu, reducer, false>(ctx, inputs, req, outputs, small);
+}
+
+
 template<typename xpu, bool normalize = false>
 inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
                                            const OpContext& ctx,
                                            const std::vector<TBlob>& inputs,
                                            const std::vector<OpReqType>& req,
                                            const std::vector<TBlob>& outputs) {
@@ -273,6 +348,24 @@ void NumpyBroadcastToBackward(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template<typename xpu, typename OP>
+void NumpyReduceAxesNoDTypeBackward(const nnvm::NodeAttrs& attrs,
+                                    const OpContext& ctx,
+                                    const std::vector<TBlob>& inputs,
+                                    const std::vector<OpReqType>& req,
+                                    const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  const NumpyReduceAxesNoDTypeParam& param = nnvm::get<NumpyReduceAxesNoDTypeParam>(attrs.parsed);
+  TShape small;
+  if (param.keepdims) {
+    small = inputs[0].shape_;
+  } else {
+    small = NumpyReduceAxesShapeImpl(outputs[0].shape_, param.axis, true);
+  }
+  ReduceAxesBackwardUseInOutImpl<xpu, OP, false>(ctx, small, inputs, req, outputs);
+}
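+// Note: OP is the elementwise mask op; the CPU/GPU registrations instantiate
+// this as NumpyReduceAxesNoDTypeBackward<xpu, mshadow_op::eq>. The three
+// inputs are [ograd, data, output] (see ReduceGrad), so the input gradient is
+// ograd broadcast back to the data shape, kept only where
+// data == broadcast(output); with ties, every tying element receives gradient.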
+
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc
index 96f0c9b350b9..bccd3af8b2cf 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cc
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -29,6 +29,7 @@ namespace mxnet {
 namespace op {
 
 DMLC_REGISTER_PARAMETER(NumpyReduceAxesParam);
+DMLC_REGISTER_PARAMETER(NumpyReduceAxesNoDTypeParam);
 
 inline bool NumpySumType(const nnvm::NodeAttrs& attrs,
                          std::vector<int> *in_attrs,
@@ -74,6 +75,71 @@ NNVM_REGISTER_OP(_backward_np_sum)
 .set_num_inputs(1)
 .set_attr<FCompute>("FCompute", NumpyReduceAxesBackwardUseNone<cpu>);
 
+inline bool NumpyReduceAxesNoDTypeType(const nnvm::NodeAttrs& attrs,
+                                       std::vector<int> *in_attrs,
+                                       std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+
+  return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
+}
+
+NNVM_REGISTER_OP(_np_max)
+.describe(R"code()code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyReduceAxesNoDTypeParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxesNoDTypeShape)
+.set_attr<nnvm::FInferType>("FInferType", NumpyReduceAxesNoDTypeType)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a"};
+  })
+.add_argument("a", "NDArray-or-Symbol", "The input")
+.add_arguments(NumpyReduceAxesNoDTypeParam::__FIELDS__())
+.set_attr<FCompute>("FCompute", NumpyReduceAxesNoDTypeCompute<cpu, mshadow::red::maximum>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<nnvm::FGradient>("FGradient", ReduceGrad{"_backward_np_max"});
+
+NNVM_REGISTER_OP(_backward_np_max)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyReduceAxesNoDTypeParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_num_inputs(3)
+.set_attr<FCompute>("FCompute", NumpyReduceAxesNoDTypeBackward<cpu, mshadow_op::eq>);
+
+NNVM_REGISTER_OP(_np_min)
+.describe(R"code()code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyReduceAxesNoDTypeParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxesNoDTypeShape)
+.set_attr<nnvm::FInferType>("FInferType", NumpyReduceAxesNoDTypeType)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a"};
+  })
+.add_argument("a", "NDArray-or-Symbol", "The input")
+.add_arguments(NumpyReduceAxesNoDTypeParam::__FIELDS__())
+.set_attr<FCompute>("FCompute", NumpyReduceAxesNoDTypeCompute<cpu, mshadow::red::minimum>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<nnvm::FGradient>("FGradient", ReduceGrad{"_backward_np_min"});
+
+NNVM_REGISTER_OP(_backward_np_min)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyReduceAxesNoDTypeParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_num_inputs(3)
+.set_attr<FCompute>("FCompute", NumpyReduceAxesNoDTypeBackward<cpu, mshadow_op::eq>);
+
 NNVM_REGISTER_OP(_np_prod)
 .set_num_inputs(1)
 .set_num_outputs(1)
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu b/src/operator/numpy/np_broadcast_reduce_op_value.cu
index a681f0d8a162..d1d33cc6d8b0 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cu
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cu
@@ -32,6 +32,18 @@ NNVM_REGISTER_OP(_np_sum)
 
 NNVM_REGISTER_OP(_backward_np_sum)
 .set_attr<FCompute>("FCompute", NumpyReduceAxesBackwardUseNone<gpu>);
 
+NNVM_REGISTER_OP(_np_max)
+.set_attr<FCompute>("FCompute", NumpyReduceAxesNoDTypeCompute<gpu, mshadow::red::maximum>);
+
+NNVM_REGISTER_OP(_backward_np_max)
+.set_attr<FCompute>("FCompute", NumpyReduceAxesNoDTypeBackward<gpu, mshadow_op::eq>);
+
+NNVM_REGISTER_OP(_np_min)
+.set_attr<FCompute>("FCompute", NumpyReduceAxesNoDTypeCompute<gpu, mshadow::red::minimum>);
+
+NNVM_REGISTER_OP(_backward_np_min)
+.set_attr<FCompute>("FCompute", NumpyReduceAxesNoDTypeBackward<gpu, mshadow_op::eq>);
+
 NNVM_REGISTER_OP(_np_prod)
 .set_attr<FCompute>("FCompute", NumpyReduceAxesCompute<gpu, mshadow_op::product, true>);
 
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index b8b9ca75a173..e8f3ce51bf52 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -277,6 +277,119 @@ def is_int(dtype):
     assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
 
 
+@with_seed()
+@use_np
+def test_np_max_min():
+    class TestMax(HybridBlock):
+        def __init__(self, axis=None, keepdims=False):
+            super(TestMax, self).__init__()
+            self._axis = axis
+            self._keepdims = keepdims
+
+        def hybrid_forward(self, F, a, *args, **kwargs):
+            return F.np.max(a, axis=self._axis, keepdims=self._keepdims)
+
+    class TestMin(HybridBlock):
+        def __init__(self, axis=None, keepdims=False):
+            super(TestMin, self).__init__()
+            self._axis = axis
+            self._keepdims = keepdims
+
+        def hybrid_forward(self, F, a, *args, **kwargs):
+            return F.np.min(a, axis=self._axis, keepdims=self._keepdims)
+
+    def is_int(dtype):
+        return 'int' == dtype
+
+    def get_grad(axis, func_name):
+        index = -1 if func_name == 'max' else 0
+        if axis == ():
+            return _np.ones((2, 3, 4, 5))
+        else:
+            temp = _np.zeros((2, 3, 4, 5))
+            if axis == 0:
+                temp[index, :, :, :] = 1
+                return temp
+            elif axis == 1:
+                temp[:, index, :, :] = 1
+                return temp
+            elif axis == 2:
+                temp[:, :, index, :] = 1
+                return temp
+            elif axis == 3:
+                temp[:, :, :, index] = 1
+                return temp
+            elif not axis:
+                temp[index, index, index, index] = 1
+                return temp
+            raise ValueError('axis should be int or None or ()')
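+    # The integer input used below is arange(120).reshape(2, 3, 4, 5), which is
+    # strictly increasing, so along any axis the max sits at index -1 and the
+    # min at index 0; get_grad builds the matching one-hot gradient for that
+    # hardcoded input.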
+
+    def _test_np_exception(func, shape, dim):
+        x = _np.random.uniform(-1.0, 1.0, shape)
+        x = mx.nd.array(x).as_np_ndarray()
+        if func == 'max':
+            out = mx.np.max(x)
+        else:
+            out = mx.np.min(x)
+        assert out.ndim == dim, 'dimension mismatch, out.ndim={}, dim={}'.format(out.ndim, dim)
+
+    in_data_dim = random.choice([2, 3, 4])
+    shape = rand_shape_nd(in_data_dim, dim=3)
+    for func in ['max', 'min']:
+        for hybridize in [False, True]:
+            for keepdims in [True, False]:
+                for axis in ([i for i in range(in_data_dim)] + [(), None]):
+                    for itype in ['float16', 'float32', 'float64', 'int']:
+                        # test gluon
+                        if func == 'max':
+                            test_gluon = TestMax(axis=axis, keepdims=keepdims)
+                        else:
+                            test_gluon = TestMin(axis=axis, keepdims=keepdims)
+                        if hybridize:
+                            test_gluon.hybridize()
+                        if is_int(itype):
+                            x = mx.nd.arange(120).reshape((2, 3, 4, 5))
+                        else:
+                            x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
+                        x = x.as_np_ndarray()
+                        x.attach_grad()
+                        if func == 'max':
+                            expected_ret = _np.amax(x.asnumpy(), axis=axis, keepdims=keepdims)
+                        else:
+                            expected_ret = _np.amin(x.asnumpy(), axis=axis, keepdims=keepdims)
+                        with mx.autograd.record():
+                            y = test_gluon(x)
+                        assert y.shape == expected_ret.shape
+                        assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)
+                        y.backward()
+                        # only check the gradient with the hardcoded integer input
+                        if is_int(itype):
+                            assert same(x.grad.asnumpy(), get_grad(axis, func)), \
+                                'x={}\ny={}\nx.grad={}\nnumpy={}'.format(
+                                    x.asnumpy(), y.asnumpy(), x.grad.asnumpy(), get_grad(axis, func))
+
+                        # test imperative
+                        if func == 'max':
+                            mx_out = np.max(x, axis=axis, keepdims=keepdims)
+                            np_out = _np.amax(x.asnumpy(), axis=axis, keepdims=keepdims)
+                        else:
+                            mx_out = np.min(x, axis=axis, keepdims=keepdims)
+                            np_out = _np.amin(x.asnumpy(), axis=axis, keepdims=keepdims)
+                        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+    # test zero-size and zero-dim inputs; every zero-size shape must raise
+    shapes = [(), (0,), (2, 0), (0, 2, 1)]
+    exceptions = [False, True, True, True]
+    dims = [0] * len(shapes)
+    for func in ['max', 'min']:
+        for shape, exception, dim in zip(shapes, exceptions, dims):
+            if exception:
+                assertRaises(MXNetError, _test_np_exception, func, shape, dim)
+            else:
+                _test_np_exception(func, shape, dim)
+
+
 @with_seed()
 @use_np
 def test_np_linspace():
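
Illustrative usage of the new operators and the zero-size check (a sketch of an
interactive session, assuming an MXNet build that includes this patch; the
error text comes from the CHECK in NumpyReduceAxesNoDTypeShape):

    import mxnet as mx
    from mxnet import np, npx
    npx.set_np()  # enable NumPy semantics, as the @use_np test decorator does

    x = np.arange(6).reshape(2, 3)       # [[0., 1., 2.], [3., 4., 5.]]
    np.max(x, axis=1)                    # array([2., 5.])
    np.min(x, axis=0, keepdims=True)     # array([[0., 1., 2.]])
    np.max(np.array(5.0))                # 0-dim input is allowed -> array(5.)
    np.max(np.zeros((2, 0)))             # raises MXNetError: zero-size array to
                                         # reduction operation maximum which has no identity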