From bd125c31c68bfef39ac904f33719677bcabab1bf Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Thu, 8 Aug 2019 20:30:50 -0700 Subject: [PATCH] numpy-compatible sum (#15810) --- src/operator/numpy/np_broadcast_reduce_op.h | 218 ++++++++++++++++++ .../numpy/np_broadcast_reduce_op_value.cc | 78 +++++++ .../numpy/np_broadcast_reduce_op_value.cu | 36 +++ tests/python/gpu/test_operator_gpu.py | 1 + tests/python/unittest/test_numpy_op.py | 97 ++++++++ 5 files changed, 430 insertions(+) create mode 100644 src/operator/numpy/np_broadcast_reduce_op.h create mode 100644 src/operator/numpy/np_broadcast_reduce_op_value.cc create mode 100644 src/operator/numpy/np_broadcast_reduce_op_value.cu create mode 100644 tests/python/unittest/test_numpy_op.py diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h new file mode 100644 index 000000000000..c516e6b0689a --- /dev/null +++ b/src/operator/numpy/np_broadcast_reduce_op.h @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file broadcast_reduce_op.h + * \brief Function definition of broadcast and reduce operators + */ +#ifndef MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_ +#define MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_ + +#include +#include +#include "../tensor/broadcast_reduce_op.h" + +namespace mxnet { +namespace op { + +struct NumpyReduceAxesParam : public dmlc::Parameter { + dmlc::optional> axis; + dmlc::optional dtype; + bool keepdims; + dmlc::optional initial; + DMLC_DECLARE_PARAMETER(NumpyReduceAxesParam) { + DMLC_DECLARE_FIELD(axis) + .set_default(dmlc::optional>()) + .describe("Axis or axes along which a sum is performed. The default, axis=None, will sum " + "all of the elements of the input array. If axis is negative it counts from the " + "last to the first axis."); + DMLC_DECLARE_FIELD(dtype) + .add_enum("float16", mshadow::kFloat16) + .add_enum("float32", mshadow::kFloat32) + .add_enum("float64", mshadow::kFloat64) + .add_enum("int8", mshadow::kInt8) + .add_enum("int32", mshadow::kInt32) + .add_enum("int64", mshadow::kInt64) + .set_default(dmlc::optional()) + .describe("The type of the returned array and of the accumulator in which the elements are " + "summed. The dtype of a is used by default unless a has an integer dtype of less " + "precision than the default platform integer. In that case, if a is signed then " + "the platform integer is used while if a is unsigned then an unsigned integer of " + "the same precision as the platform integer is used."); + DMLC_DECLARE_FIELD(keepdims).set_default(false) + .describe("If this is set to `True`, the reduced axes are left " + "in the result as dimension with size one."); + DMLC_DECLARE_FIELD(initial).set_default(dmlc::optional()) + .describe("Starting value for the sum."); + } +}; + +inline TShape NumpyReduceAxesShapeImpl(const TShape& ishape, + const dmlc::optional>& axis, + bool keepdims) { + // TODO(junwu): improve the logic + // If input is a scalar, output should be a scalar too + if (ishape.ndim() == 0) { + if (axis.has_value()) { + const mxnet::Tuple& axes = axis.value(); + if (axes.ndim() > 0) { + CHECK_EQ(axes.ndim(), 1); + CHECK(axes[0] == 0 || axes[0] == -1); + } + } + return TShape(0, -1); + } + + // axis=None, do global reduction + if (!axis.has_value()) { + if (keepdims) { + return TShape(ishape.ndim(), 1); + } else { + return TShape(0, -1); + } + } + + // axis = (), will return identity(input) + if (axis.value().ndim() == 0) { + return ishape; + } + + // axis has value + mxnet::Tuple axes(axis.value()); + for (index_t i = 0; i < axes.ndim(); i++) { + if (axes[i] < 0) { + axes[i] += ishape.ndim(); + } + } + std::sort(axes.begin(), axes.end()); + + for (index_t i = 1; i < axes.ndim(); i++) { + CHECK_LT(axes[i-1], axes[i]) + << "Reduction axes have duplicates " + << axes; + } + CHECK_LT(axes[axes.ndim()-1], ishape.ndim()) + << "Reduction axis " << axes[axes.ndim()-1] + << " Exceeds input dimensions " << ishape; + CHECK_GE(axes[0], 0) + << "Reduction axis " << axis.value() + << " Exceeds input dimensions " << ishape; + + TShape oshape; + if (keepdims) { + oshape = TShape(ishape); + } else { + oshape = TShape(ishape.ndim() - axes.ndim(), -1); + } + + if (keepdims) { + for (index_t i = 0; i < axes.ndim(); ++i) { + oshape[axes[i]] = 1; + } + } else { + for (index_t i = 0, j = 0, k = 0; i < ishape.ndim(); ++i) { + if (j < axes.ndim() && i == axes[j]) { + ++j; + continue; + } + oshape[k++] = ishape[i]; + } + } + return oshape; +} + +inline bool NumpyReduceAxesShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + if (!shape_is_known(in_attrs->at(0))) { + return false; + } + const NumpyReduceAxesParam& param = nnvm::get(attrs.parsed); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, + NumpyReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims)); + return shape_is_known(out_attrs->at(0)); +} + +template +inline bool NeedSafeAcc(int itype, int otype) { + bool rule = (itype != otype) || (itype != mshadow::kFloat32 && itype != mshadow::kFloat64); + return safe_acc_hint && rule; +} + +template +void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const NumpyReduceAxesParam& param = nnvm::get(attrs.parsed); + if (param.initial.has_value()) { + LOG(FATAL) << "initial is not supported yet"; + } + if (param.axis.has_value() && param.axis.value().ndim() == 0) { + UnaryOp::IdentityCompute(attrs, ctx, inputs, req, outputs); + } + TShape small; + if (param.keepdims) { + small = outputs[0].shape_; + } else { + small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true); + } + + if (NeedSafeAcc(inputs[0].type_flag_, outputs[0].type_flag_)) { + ReduceAxesComputeImpl(ctx, inputs, req, outputs, small); + } else { + ReduceAxesComputeImpl(ctx, inputs, req, outputs, small); + } +} + +template +inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + const NumpyReduceAxesParam& param = nnvm::get(attrs.parsed); + TShape small; + if (param.keepdims) { + small = inputs[0].shape_; + } else { + small = NumpyReduceAxesShapeImpl(outputs[0].shape_, param.axis, true); + } + + BroadcastComputeImpl(attrs, ctx, inputs, req, outputs, small); + if (normalize) { + Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, { + Tensor igrad = outputs[0].FlatTo1D(s); + printf("output size: %lu input_size: %lu\n", outputs[0].Size(), inputs[0].Size()); + igrad /= scalar(outputs[0].Size()/inputs[0].Size()); + }); + } +} + +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_ diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc new file mode 100644 index 000000000000..fd227de168e1 --- /dev/null +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_reduce_op_value.cc + * \brief CPU Implementation of broadcast and reduce functions based on value. + */ + +#include "np_broadcast_reduce_op.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(NumpyReduceAxesParam); + +inline bool NumpySumType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + const NumpyReduceAxesParam ¶m = nnvm::get(attrs.parsed); + + if (param.dtype.has_value()) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value()); + } else { + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + } + + return out_attrs->at(0) != -1 && in_attrs->at(0) != -1; +} + +NNVM_REGISTER_OP(_np_sum) +.describe(R"code()code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", NumpyReduceAxesShape) +.set_attr("FInferType", NumpySumType) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a"}; + }) +.add_argument("a", "NDArray-or-Symbol", "The input") +.add_arguments(NumpyReduceAxesParam::__FIELDS__()) +.set_attr("FCompute", NumpyReduceAxesCompute) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_np_sum"}); + +NNVM_REGISTER_OP(_backward_np_sum) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("TIsBackward", true) +.set_num_inputs(1) +.set_attr("FCompute", NumpyReduceAxesBackwardUseNone); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu b/src/operator/numpy/np_broadcast_reduce_op_value.cu new file mode 100644 index 000000000000..1eb773009a79 --- /dev/null +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cu @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_reduce_op_value.cu + * \brief GPU Implementation of reduce functions based on value. + */ +#include "np_broadcast_reduce_op.h" + +namespace mxnet { +namespace op { +NNVM_REGISTER_OP(_np_sum) +.set_attr("FCompute", NumpyReduceAxesCompute); + +NNVM_REGISTER_OP(_backward_np_sum) +.set_attr("FCompute", NumpyReduceAxesBackwardUseNone); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 0b5142dd8c8b..f8d8b4496afc 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -36,6 +36,7 @@ from common import run_in_spawned_process from test_operator import * from test_numpy_ndarray import * +from test_numpy_op import * from test_optimizer import * from test_random import * from test_exc_handling import * diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py new file mode 100644 index 000000000000..b179f67e6128 --- /dev/null +++ b/tests/python/unittest/test_numpy_op.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: skip-file +from __future__ import absolute_import +import numpy as _np +import mxnet as mx +from mxnet import np, npx +from mxnet.base import MXNetError +from mxnet.gluon import HybridBlock +from mxnet.base import MXNetError +from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray +from mxnet.test_utils import check_numeric_gradient, use_np +from common import assertRaises, with_seed +import random +import collections + + +@with_seed() +@use_np +def test_np_sum(): + class TestSum(HybridBlock): + def __init__(self, axis=None, dtype=None, keepdims=False): + super(TestSum, self).__init__() + self._axis = axis + self._dtype = dtype + self._keepdims = keepdims + + def hybrid_forward(self, F, a, *args, **kwargs): + return F.np.sum(a, axis=self._axis, dtype=self._dtype, keepdims=self._keepdims) + + def is_int(dtype): + return 'int' in dtype + + in_data_dim = random.choice([2, 3, 4]) + shape = rand_shape_nd(in_data_dim, dim=3) + acc_type = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64', + 'int8': 'int32', 'int32': 'int64', 'int64': 'int64'} + for hybridize in [False, True]: + for keepdims in [True, False]: + for axis in ([i for i in range(in_data_dim)] + [(), None]): + for itype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']: + for dtype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']: + if is_int(dtype) and not is_int(itype): + continue + # test gluon + test_sum = TestSum(axis=axis, dtype=dtype, keepdims=keepdims) + if hybridize: + test_sum.hybridize() + if is_int(itype): + x = _np.random.randint(-128, 128, shape, dtype=itype) + x = mx.nd.array(x) + else: + x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype) + x = x.as_np_ndarray() + x.attach_grad() + expected_ret = _np.sum(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims) + expected_ret = expected_ret.astype(dtype) + with mx.autograd.record(): + y = test_sum(x) + assert y.shape == expected_ret.shape + assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3, + atol=1e-5 if dtype == 'float16' else 1e-5) + + y.backward() + assert same(x.grad.asnumpy(), _np.ones(shape=x.shape, dtype=x.dtype)) + + # test numeric + if itype == 'float32' and dtype == 'float32': + x_sym = mx.sym.Variable("x").as_np_ndarray() + mx_sym = mx.sym.np.sum(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_nd_ndarray() + check_numeric_gradient(mx_sym, [x.as_nd_ndarray()], + numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32) + + # test imperative + mx_out = np.sum(x, axis=axis, dtype=dtype, keepdims=keepdims) + np_out = _np.sum(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims).astype(dtype) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + +if __name__ == '__main__': + import nose + nose.runmodule()