diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index dd818457f827..41be554953fd 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -254,7 +254,7 @@ using FNDArrayFunction = std::function<void (const nnvm::NodeAttrs& attrs,
                                              const std::vector<NDArray>& inputs,
                                              std::vector<NDArray>* outputs)>;
 /*!
- * \brief Resiger a compute function for simple stateless forward only operator
+ * \brief Register a compute function for simple stateless forward only operator
  *
  * \note Register under "FCompute<cpu>" and "FCompute<gpu>"
  */
@@ -264,7 +264,7 @@ using FCompute = std::function<void (const nnvm::NodeAttrs& attrs,
                                      const std::vector<OpReqType>& req,
                                      const std::vector<TBlob>& outputs)>;
 /*!
- * \brief Resiger an NDArray compute function for simple stateless forward only operator
+ * \brief Register an NDArray compute function for simple stateless forward only operator
  * \note Register under "FComputeEx<cpu>" and "FComputeEx<gpu>"
  *       Dispatched only when inferred dispatch_mode is FDispatchComputeEx
  */
@@ -275,7 +275,7 @@ using FComputeEx = std::function<void (const nnvm::NodeAttrs& attrs,
                                        const std::vector<NDArray>& outputs)>;
 /*!
- * \brief Resiger a storage and dispatch mode inference function based on
+ * \brief Register a storage and dispatch mode inference function based on
  *        storage types of the inputs and outputs, and the dev_mask for the operator.
  *
  * \note Register under "FInferStorageType"
 */
diff --git a/src/operator/contrib/gradient_multiplier_op.cc b/src/operator/contrib/gradient_multiplier_op.cc
new file mode 100644
index 000000000000..47f891ef802b
--- /dev/null
+++ b/src/operator/contrib/gradient_multiplier_op.cc
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file gradient_multiplier_op.cc
+ * \brief Gradient multiplier operator.
+ * \author Istvan Fehervari
+*/
+#include "../tensor/elemwise_unary_op.h"
+#include "../tensor/elemwise_binary_scalar_op.h"
+
+namespace mxnet {
+namespace op {
+
+static bool BinaryScalarStorageType(const nnvm::NodeAttrs& attrs,
+                                    const int dev_mask,
+                                    DispatchMode* dispatch_mode,
+                                    std::vector<int>* in_attrs,
+                                    std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1);
+  CHECK_EQ(out_attrs->size(), 1);
+  const auto in_stype = in_attrs->at(0);
+  auto &out_stype = out_attrs->at(0);
+  bool dispatched = false;
+  if (!dispatched && (in_stype == kDefaultStorage)) {
+    // dense -> dense
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+  if (!dispatched && in_stype == kRowSparseStorage) {
+    // row sparse -> row sparse
+    dispatched = storage_type_assign(&out_stype, kRowSparseStorage,
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+    // FComputeEx can handle dns output on cpu, too
+    if (dev_mask == cpu::kDevMask && out_stype == kDefaultStorage) {
+      DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+      dispatched = true;
+    }
+  }
+  if (!dispatched && in_stype == kCSRStorage) {
+    // csr -> csr
+    dispatched = storage_type_assign(&out_stype, kCSRStorage,
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+    // FComputeEx can handle dns output on cpu, too
+    if (dev_mask == cpu::kDevMask && out_stype == kDefaultStorage) {
+      DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+      dispatched = true;
+    }
+  }
+  if (!dispatched) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }
+  return dispatched;
+}
+
+MXNET_OPERATOR_REGISTER_UNARY(_contrib_gradientmultiplier)
+.describe(R"code(This operator implements the gradient multiplier function.
+In the forward pass it acts as an identity transform. During backpropagation it
+multiplies the gradient from the subsequent level by a scalar factor lambda and
+passes it to the preceding layer.
+)code" ADD_FILELINE)
+.set_attr_parser([](NodeAttrs* attrs) {
+    attrs->parsed = std::stod(attrs->dict["scalar"]);
+  })
+.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1, false, true, true>)
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", UnaryOp::IdentityComputeEx<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_contrib_backward_gradientmultiplier"})
+.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
+  [](const NodeAttrs& attrs){
+    return std::vector<bool>{true};
+  })
+.add_argument("scalar", "float", "lambda multiplier");
+
+MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_contrib_backward_gradientmultiplier)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FInferStorageType>("FInferStorageType", BinaryScalarStorageType)
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::mul>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryScalarOp::ComputeEx<cpu, mshadow_op::mul>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/contrib/gradient_multiplier_op.cu b/src/operator/contrib/gradient_multiplier_op.cu
new file mode 100644
index 000000000000..7159cea9805d
--- /dev/null
+++ b/src/operator/contrib/gradient_multiplier_op.cu
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file gradient_multiplier_op.cu
+ * \brief Gradient multiplier operator (GPU registrations).
+ * \author Istvan Fehervari
+*/
+#include "../tensor/elemwise_unary_op.h"
+#include "../tensor/elemwise_binary_scalar_op.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_contrib_gradientmultiplier)
+.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::IdentityComputeEx<gpu>)
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
+
+NNVM_REGISTER_OP(_contrib_backward_gradientmultiplier)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::mul>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryScalarOp::ComputeEx<gpu, mshadow_op::mul>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/test_contrib_operator.py b/tests/python/unittest/test_contrib_operator.py
index 43d3db648a85..aac807660af1 100644
--- a/tests/python/unittest/test_contrib_operator.py
+++ b/tests/python/unittest/test_contrib_operator.py
@@ -261,6 +261,42 @@ def test_multibox_target_op():
     assert_array_equal(loc_mask.asnumpy(), expected_loc_mask)
     assert_array_equal(cls_target.asnumpy(), expected_cls_target)
 
+def test_gradient_multiplier_op():
+    # We use the quadratic function in combination with the gradient multiplier
+    def f(x, a, b, c):
+        return a * x**2 + b * x + c
+
+    a = np.random.random_sample()
+    b = np.random.random_sample()
+    c = np.random.random_sample()
+    m = np.random.random_sample() - 0.5
+
+    data = mx.symbol.Variable('data')
+    quad_sym = mx.sym.contrib.quadratic(data=data, a=a, b=b, c=c)
+    gr_q_sym = mx.sym.contrib.gradientmultiplier(quad_sym, scalar=m)
+
+    for dtype in [np.float16, np.float32, np.float64]:
+        for ndim in range(1, 6):
+            shape = rand_shape_nd(ndim, 5)
+            data_np = np.random.randn(*shape).astype(dtype)
+            expected = f(data_np, a, b, c)
+            backward_expected = (2 * a * data_np + b) * m
+
+            # check imperative forward
+            output = mx.nd.contrib.quadratic(mx.nd.array(data_np), a=a, b=b, c=c)
+            output = mx.nd.contrib.gradientmultiplier(output, scalar=m)
+            assert_almost_equal(output.asnumpy(), expected,
+                                rtol=1e-2 if dtype is np.float16 else 1e-5,
+                                atol=1e-2 if dtype is np.float16 else 1e-5)
+            # check symbolic forward
+            check_symbolic_forward(gr_q_sym, [data_np], [expected],
+                                   rtol=1e-2 if dtype is np.float16 else 1e-5,
+                                   atol=1e-2 if dtype is np.float16 else 1e-5)
+            # check symbolic backward
+            check_symbolic_backward(gr_q_sym, [data_np], [np.ones(expected.shape)],
+                                    [backward_expected],
+                                    rtol=1e-2 if dtype is np.float16 else 1e-5,
+                                    atol=1e-2 if dtype is np.float16 else 1e-5)
+
 if __name__ == '__main__':
     import nose
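
Reviewer note (not part of the patch): a minimal imperative usage sketch, assuming a build
that includes this contrib operator. With a negative scalar the op behaves as the
gradient-reversal layer used in domain-adversarial training. The expected values are worked
out by hand from the operator's contract (identity forward, gradient scaled by `scalar` on
the way back).

    import mxnet as mx
    from mxnet import autograd

    x = mx.nd.array([1.0, 2.0, 3.0])
    x.attach_grad()
    with autograd.record():
        # forward pass is the identity, so y == x
        y = mx.nd.contrib.gradientmultiplier(x, scalar=-1.0)
        loss = (y * y).sum()
    loss.backward()
    # d(loss)/dy = 2*y = [2, 4, 6]; the backward pass multiplies this by -1.0
    print(x.grad.asnumpy())  # -> [-2. -4. -6.]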