AdamW operator (Fixing Weight Decay Regularization in Adam) (apache#13728)

* tests

* remove optimizer and move op to contrib

* rename parameter
eric-haibin-lin authored and haohuw committed Jun 23, 2019
1 parent 68a30bd commit 948d453
Showing 4 changed files with 232 additions and 3 deletions.
125 changes: 125 additions & 0 deletions src/operator/contrib/adamw-inl.h
@@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2016 by Contributors
* \file adamw-inl.h
* \brief AdamW optimizer operator
* \author Haibin Lin
*/
#ifndef MXNET_OPERATOR_CONTRIB_ADAMW_INL_H_
#define MXNET_OPERATOR_CONTRIB_ADAMW_INL_H_
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <mxnet/operator_util.h>
#include <mxnet/op_attr_types.h>
#include <mshadow/base.h>
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include <vector>
#include "../operator_common.h"
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
#include "../mxnet_op.h"

namespace mxnet {
namespace op {

struct AdamWParam : public dmlc::Parameter<AdamWParam> {
  float lr;
  float beta1;
  float beta2;
  float epsilon;
  float wd;
  float eta;
  float rescale_grad;
  float clip_gradient;
  DMLC_DECLARE_PARAMETER(AdamWParam) {
    DMLC_DECLARE_FIELD(lr)
    .describe("Learning rate");
    DMLC_DECLARE_FIELD(beta1)
    .set_default(0.9f)
    .describe("The decay rate for the 1st moment estimates.");
    DMLC_DECLARE_FIELD(beta2)
    .set_default(0.999f)
    .describe("The decay rate for the 2nd moment estimates.");
    DMLC_DECLARE_FIELD(epsilon)
    .set_default(1e-8f)
    .describe("A small constant for numerical stability.");
    DMLC_DECLARE_FIELD(wd)
    .set_default(0.0f)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude of each weight.");
    DMLC_DECLARE_FIELD(eta)
    .describe("Learning rate schedule multiplier");
    DMLC_DECLARE_FIELD(rescale_grad)
    .set_default(1.0f)
    .describe("Rescale gradient to grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_gradient)
    .set_default(-1.0f)
    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
              "If clip_gradient <= 0, gradient clipping is turned off. "
              "grad = max(min(grad, clip_gradient), -clip_gradient).");
  }
};

/*
* \brief adam_w update.
*/
template<typename xpu>
inline void AdamWUpdate(const nnvm::NodeAttrs& attrs,
                        const OpContext &ctx,
                        const std::vector<TBlob> &inputs,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  using namespace mshadow_op;
  const AdamWParam& param = nnvm::get<AdamWParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

    grad = scalar<DType>(param.rescale_grad) * grad;
    if (param.clip_gradient >= 0.0f) {
      mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) *
             F<clip>(grad, DType(param.clip_gradient));
      var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2)*F<square>(
            F<clip>(grad, DType(param.clip_gradient)));
    } else {
      mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * grad;
      var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
    }
    Assign(out, req[0],
           weight -
           scalar<DType>(param.eta) * (scalar<DType>(param.lr) *
           mean / (F<square_root>(var) + scalar<DType>(param.epsilon)) +
           (scalar<DType>(param.wd) * weight)));
  });
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_CONTRIB_ADAMW_INL_H_
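For reference, the AdamWUpdate kernel above performs the decoupled weight decay update that the .cc registration below documents. A minimal NumPy sketch of the same arithmetic (illustrative only, not part of this commit; the function name and defaults simply mirror AdamWParam):

# Minimal NumPy sketch of the AdamWUpdate arithmetic (illustrative, not part of the commit).
# Parameter names and defaults mirror AdamWParam; `mean` and `var` are updated in place,
# matching the operator's FMutateInputs behaviour.
import numpy as np

def adamw_update_ref(weight, grad, mean, var, lr, eta, beta1=0.9, beta2=0.999,
                     epsilon=1e-8, wd=0.0, rescale_grad=1.0, clip_gradient=-1.0):
    grad = rescale_grad * grad
    if clip_gradient >= 0.0:
        grad = np.clip(grad, -clip_gradient, clip_gradient)
    mean[:] = beta1 * mean + (1.0 - beta1) * grad            # 1st moment estimate
    var[:] = beta2 * var + (1.0 - beta2) * np.square(grad)   # 2nd moment estimate
    # Decoupled weight decay: wd * weight is applied outside the adaptive term.
    return weight - eta * (lr * mean / (np.sqrt(var) + epsilon) + wd * weight)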
71 changes: 71 additions & 0 deletions src/operator/contrib/adamw.cc
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2016 by Contributors
* \file adamw.cc
* \brief AdamW optimizer operator
* \author Haibin Lin
*/
#include "./adamw-inl.h"

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(AdamWParam);

NNVM_REGISTER_OP(_contrib_adamw_update)
.describe(R"code(Update function for AdamW optimizer. AdamW is seen as a modification of
Adam by decoupling the weight decay from the optimization steps taken w.r.t. the loss function.

Adam update consists of the following steps, where g represents gradient and m, v
are 1st and 2nd order moment estimates (mean and variance).

.. math::

 g_t = \nabla J(W_{t-1})\\
 m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
 v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
 W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1})

It updates the weights using::

 m = beta1*m + (1-beta1)*grad
 v = beta2*v + (1-beta2)*(grad**2)
 w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)

)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<AdamWParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2, 3};
  })
.set_attr<FCompute>("FCompute<cpu>", AdamWUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_arguments(AdamWParam::__FIELDS__());

} // namespace op
} // namespace mxnet
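Assuming the usual exposure of contrib operators in the Python frontend (an assumption; this commit only adds the C++ registration), the new op would be called roughly as in the sketch below. Note that mean and var are mutated in place per FMutateInputs, while the updated weight is returned as the output.

# Hedged usage sketch: assumes _contrib_adamw_update is exposed as mx.nd.contrib.adamw_update.
import mxnet as mx

shape = (3, 4)
weight = mx.nd.random.uniform(shape=shape)
grad = mx.nd.random.uniform(shape=shape)
mean = mx.nd.zeros(shape)  # 1st moment state, mutated in place (input index 2)
var = mx.nd.zeros(shape)   # 2nd moment state, mutated in place (input index 3)

new_weight = mx.nd.contrib.adamw_update(weight, grad, mean, var,
                                        lr=0.001, eta=1.0, wd=0.01,
                                        rescale_grad=1.0, clip_gradient=1.0)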
35 changes: 35 additions & 0 deletions src/operator/contrib/adamw.cu
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2018 by Contributors
* \file adamw.cu
* \brief Optimizer operators
* \author Haibin Lin
*/
#include "./adamw-inl.h"

namespace mxnet {
namespace op {

NNVM_REGISTER_OP(_contrib_adamw_update)
.set_attr<FCompute>("FCompute<gpu>", AdamWUpdate<gpu>);

} // namespace op
} // namespace mxnet
4 changes: 1 addition & 3 deletions tests/python/unittest/test_optimizer.py
@@ -506,12 +506,11 @@ def test_ftml():
class PyAdam(mx.optimizer.Optimizer):
    """python reference implementation of adam"""
    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
                 decay_factor=(1 - 1e-8), lazy_update=True, **kwargs):
                 lazy_update=True, **kwargs):
        super(PyAdam, self).__init__(learning_rate=learning_rate, **kwargs)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.decay_factor = decay_factor
        self.lazy_update = lazy_update

    def create_state(self, index, weight):
@@ -614,7 +613,6 @@ def test_adam():
dtype, w_stype='default', g_stype='row_sparse',
rtol=1e-4, atol=2e-5)


# AdaMax
class PyAdamax(mx.optimizer.Optimizer):
"""The python reference of AdaMax optimizer.
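Given the PyAdam-style reference implementations in this file, a check for the new contrib op could look like the hedged sketch below (hypothetical; check_contrib_adamw is not part of this commit, and the tolerances simply follow the existing test_adam call).

# Hypothetical reference check for the new op, in the style of the existing optimizer tests.
import numpy as np
import mxnet as mx
from mxnet.test_utils import assert_almost_equal

def check_contrib_adamw(shape=(4, 5), lr=0.001, eta=1.0, wd=0.1):
    weight = np.random.uniform(size=shape).astype(np.float32)
    grad = np.random.uniform(size=shape).astype(np.float32)
    mean = np.zeros(shape, dtype=np.float32)
    var = np.zeros(shape, dtype=np.float32)
    # NumPy expectation with the operator's default beta1/beta2/epsilon and no clipping.
    m = 0.9 * mean + 0.1 * grad
    v = 0.999 * var + 0.001 * grad ** 2
    expected = weight - eta * (lr * m / (np.sqrt(v) + 1e-8) + wd * weight)
    out = mx.nd.contrib.adamw_update(mx.nd.array(weight), mx.nd.array(grad),
                                     mx.nd.array(mean), mx.nd.array(var),
                                     lr=lr, eta=eta, wd=wd)
    assert_almost_equal(out.asnumpy(), expected, rtol=1e-4, atol=2e-5)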
