From 116d01e9f2faef88010fbb98517260cbe0d2801f Mon Sep 17 00:00:00 2001
From: Haibin Lin
Date: Thu, 27 Dec 2018 22:06:32 -0800
Subject: [PATCH] AdamW operator (Fixing Weight Decay Regularization in Adam)
 (#13728)

* tests

* remove optimizer and move op to contrib

* rename parameter
---
 src/operator/contrib/adamw-inl.h        | 125 ++++++++++++++++++++++++
 src/operator/contrib/adamw.cc           |  71 ++++++++++++++
 src/operator/contrib/adamw.cu           |  35 +++++++
 tests/python/unittest/test_optimizer.py |   4 +-
 4 files changed, 232 insertions(+), 3 deletions(-)
 create mode 100644 src/operator/contrib/adamw-inl.h
 create mode 100644 src/operator/contrib/adamw.cc
 create mode 100644 src/operator/contrib/adamw.cu

diff --git a/src/operator/contrib/adamw-inl.h b/src/operator/contrib/adamw-inl.h
new file mode 100644
index 000000000000..3d76b33ae765
--- /dev/null
+++ b/src/operator/contrib/adamw-inl.h
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2016 by Contributors
+ * \file optimizer_op-inl.h
+ * \brief Optimizer operators
+ * \author Haibin Lin
+ */
+#ifndef MXNET_OPERATOR_CONTRIB_ADAMW_INL_H_
+#define MXNET_OPERATOR_CONTRIB_ADAMW_INL_H_
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <mxnet/operator_util.h>
+#include <mxnet/op_attr_types.h>
+#include <mshadow/base.h>
+#include <nnvm/op.h>
+#include <nnvm/op_attr_types.h>
+#include <vector>
+#include "../operator_common.h"
+#include "../mshadow_op.h"
+#include "../elemwise_op_common.h"
+#include "../mxnet_op.h"
+
+namespace mxnet {
+namespace op {
+
+struct AdamWParam : public dmlc::Parameter<AdamWParam> {
+  float lr;
+  float beta1;
+  float beta2;
+  float epsilon;
+  float wd;
+  float eta;
+  float rescale_grad;
+  float clip_gradient;
+  DMLC_DECLARE_PARAMETER(AdamWParam) {
+    DMLC_DECLARE_FIELD(lr)
+    .describe("Learning rate");
+    DMLC_DECLARE_FIELD(beta1)
+    .set_default(0.9f)
+    .describe("The decay rate for the 1st moment estimates.");
+    DMLC_DECLARE_FIELD(beta2)
+    .set_default(0.999f)
+    .describe("The decay rate for the 2nd moment estimates.");
+    DMLC_DECLARE_FIELD(epsilon)
+    .set_default(1e-8f)
+    .describe("A small constant for numerical stability.");
+    DMLC_DECLARE_FIELD(wd)
+    .set_default(0.0f)
+    .describe("Weight decay augments the objective function with a "
+              "regularization term that penalizes large weights. "
+              "The penalty scales with the square of the magnitude of each weight.");
+    DMLC_DECLARE_FIELD(eta)
+    .describe("Learning rate schedule multiplier");
+    DMLC_DECLARE_FIELD(rescale_grad)
+    .set_default(1.0f)
+    .describe("Rescale gradient to grad = rescale_grad*grad.");
+    DMLC_DECLARE_FIELD(clip_gradient)
+    .set_default(-1.0f)
+    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
+              "If clip_gradient <= 0, gradient clipping is turned off. "
+              "grad = max(min(grad, clip_gradient), -clip_gradient).");
+  }
+};
+
+/*
+ * \brief adam_w update.
+ */
+template<typename xpu>
+inline void AdamWUpdate(const nnvm::NodeAttrs& attrs,
+                        const OpContext &ctx,
+                        const std::vector<TBlob> &inputs,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mshadow_op;
+  const AdamWParam& param = nnvm::get<AdamWParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+
+    grad = scalar<DType>(param.rescale_grad) * grad;
+    if (param.clip_gradient >= 0.0f) {
+      mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) *
+             F<clip>(grad, DType(param.clip_gradient));
+      var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2)*F<square>(
+            F<clip>(grad, DType(param.clip_gradient)));
+    } else {
+      mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * grad;
+      var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
+    }
+    Assign(out, req[0],
+           weight -
+           scalar<DType>(param.eta) * (scalar<DType>(param.lr) *
+           mean / (F<square_root>(var) + scalar<DType>(param.epsilon)) +
+           (scalar<DType>(param.wd) * weight)));
+  });
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_CONTRIB_ADAMW_INL_H_
diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adamw.cc
new file mode 100644
index 000000000000..94623fe08a9e
--- /dev/null
+++ b/src/operator/contrib/adamw.cc
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2016 by Contributors
+ * \file optimizer_op.cc
+ * \brief Optimizer operators
+ * \author Haibin Lin
+ */
+#include "./adamw-inl.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(AdamWParam);
+
+NNVM_REGISTER_OP(_contrib_adamw_update)
+.describe(R"code(Update function for the AdamW optimizer. AdamW is a variant of Adam
+that decouples the weight decay from the optimization step taken w.r.t. the loss function.
+
+The AdamW update consists of the following steps, where g represents the gradient and m, v
+are the 1st and 2nd order moment estimates (mean and variance).
+
+.. math::
+
+ g_t = \nabla J(W_{t-1})\\
+ m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
+ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
+ W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1})
+
+It updates the weights using::
+
+ m = beta1*m + (1-beta1)*grad
+ v = beta2*v + (1-beta2)*(grad**2)
+ w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)
+
+)code" ADD_FILELINE)
+.set_num_inputs(4)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<AdamWParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<uint32_t>{2, 3};
+  })
+.set_attr<FCompute>("FCompute", AdamWUpdate<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
+.add_argument("var", "NDArray-or-Symbol", "Moving variance")
+.add_arguments(AdamWParam::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adamw.cu
new file mode 100644
index 000000000000..b7452f861e2d
--- /dev/null
+++ b/src/operator/contrib/adamw.cu
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file adamw.cu
+ * \brief Optimizer operators
+ * \author Haibin Lin
+ */
+#include "./adamw-inl.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_contrib_adamw_update)
+.set_attr<FCompute>("FCompute", AdamWUpdate<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index acf24ee1b794..eb33f9b5217e 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -506,12 +506,11 @@ def test_ftml():
 class PyAdam(mx.optimizer.Optimizer):
     """python reference implemenation of adam"""
     def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
-                 decay_factor=(1 - 1e-8), lazy_update=True, **kwargs):
+                 lazy_update=True, **kwargs):
         super(PyAdam, self).__init__(learning_rate=learning_rate, **kwargs)
         self.beta1 = beta1
         self.beta2 = beta2
         self.epsilon = epsilon
-        self.decay_factor = decay_factor
         self.lazy_update = lazy_update
 
     def create_state(self, index, weight):
@@ -614,7 +613,6 @@ def test_adam():
                           dtype, w_stype='default', g_stype='row_sparse',
                           rtol=1e-4, atol=2e-5)
 
-
 # AdaMax
 class PyAdamax(mx.optimizer.Optimizer):
     """The python reference of AdaMax optimizer.
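
For reference, the decoupled update that ``_contrib_adamw_update`` implements can be expressed in a few
lines of NumPy. This is a minimal sketch of the documented formula, assuming dense float arrays; the
function name and signature below are illustrative and are not part of the patch::

    import numpy as np

    def adamw_update(weight, grad, mean, var, lr, eta, beta1=0.9, beta2=0.999,
                     epsilon=1e-8, wd=0.0, rescale_grad=1.0, clip_gradient=-1.0):
        """NumPy sketch of one AdamW step; states are updated in place."""
        grad = rescale_grad * grad
        if clip_gradient >= 0:
            grad = np.clip(grad, -clip_gradient, clip_gradient)
        mean[:] = beta1 * mean + (1.0 - beta1) * grad           # 1st moment estimate
        var[:] = beta2 * var + (1.0 - beta2) * np.square(grad)  # 2nd moment estimate
        # Decoupled weight decay: wd * weight is applied outside the adaptive term.
        weight[:] -= eta * (lr * mean / (np.sqrt(var) + epsilon) + wd * weight)
        return weight, mean, var

Once the operator is built, the same step should be reachable from the Python frontend as
``mx.nd.contrib.adamw_update(weight, grad, mean, var, lr=..., eta=..., ...)``, with ``mean`` and ``var``
mutated in place (inputs 2 and 3 are marked mutable in the registration above).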