Signum optimizer (apache#9220)

* the C++ version of the Signum and SignSGD optimizers

* optimizer Signum, tested working on Mac (CPU) using MNIST

* unit test for signum

* fix lint and incorporate Haibin's code review

* rerun jenkins

* add a link to the Loshchilov and Hutter paper to the documentation
yuxiangw authored and piiswrong committed Jan 12, 2018
1 parent 68750bc commit 5251b86
Showing 7 changed files with 433 additions and 3 deletions.
14 changes: 14 additions & 0 deletions cpp-package/include/mxnet-cpp/optimizer.h
@@ -146,6 +146,20 @@ class SGDOptimizer : public Optimizer {
  AtomicSymbolCreator mom_update_handle_;
};

class SignumOptimizer : public Optimizer {
 public:
  explicit SignumOptimizer(unsigned begin_num_update = 0);
  std::string GetType() const override;
  void Update(int index, NDArray weight, NDArray grad) override;
 private:
  virtual ~SignumOptimizer();
  void CreateState_(int index, NDArray weight) override;
  std::map<int, NDArray*> states_;
  AtomicSymbolCreator update_handle_;
  AtomicSymbolCreator mom_update_handle_;
};


class RMSPropOptimizer : public Optimizer {
 public:
  explicit RMSPropOptimizer(unsigned begin_num_update = 0);
64 changes: 64 additions & 0 deletions cpp-package/include/mxnet-cpp/optimizer.hpp
@@ -131,6 +131,7 @@ inline Optimizer* OptimizerRegistry::Find(const std::string& name) {
  MXNETCPP_REGISTER_OPTIMIZER(adam, AdamOptimizer);
  MXNETCPP_REGISTER_OPTIMIZER(adagrad, AdaGradOptimizer);
  MXNETCPP_REGISTER_OPTIMIZER(adadelta, AdaDeltaOptimizer);
  MXNETCPP_REGISTER_OPTIMIZER(signum, SignumOptimizer);
  auto it = cmap().find(name);
  if (it == cmap().end())
    return nullptr;
@@ -200,6 +201,69 @@ inline void SGDOptimizer::CreateState_(int index, NDArray weight) {
  }
}

// implementing the Signum optimizer

inline SignumOptimizer::SignumOptimizer(unsigned begin_num_update)
    : Optimizer(begin_num_update) {
  update_handle_ = op_map()->GetSymbolCreator("signsgd_update");
  mom_update_handle_ = op_map()->GetSymbolCreator("signum_update");
}

inline std::string SignumOptimizer::GetType() const {
  return "signum";
}

inline SignumOptimizer::~SignumOptimizer() {
  for (auto &it : states_) {
    delete it.second;
  }
}

inline void SignumOptimizer::Update(int index, NDArray weight, NDArray grad) {
  if (states_.count(index) == 0) {
    CreateState_(index, weight);
  }

  params_["lr"] = std::to_string(GetLR_(index));
  params_["wd"] = std::to_string(GetWD_(index));
  UpdateCount_(index);
  auto keys = GetParamKeys_();
  auto values = GetParamValues_();
  CHECK_EQ(keys.size(), values.size());

  NDArrayHandle inputs[3];
  inputs[0] = weight.GetHandle();
  inputs[1] = grad.GetHandle();

  int num_outputs = 1;
  NDArrayHandle output = weight.GetHandle();
  NDArrayHandle *outputs = &output;

  if (states_[index] == nullptr) {
    MXImperativeInvoke(update_handle_, 2, inputs,
                       &num_outputs, &outputs,
                       keys.size(), keys.data(), values.data());
  } else {
    inputs[2] = states_[index]->GetHandle();
    MXImperativeInvoke(mom_update_handle_, 3, inputs,
                       &num_outputs, &outputs,
                       keys.size(), keys.data(), values.data());
  }
}

inline void SignumOptimizer::CreateState_(int index, NDArray weight) {
  if (params_.count("momentum") == 0) {
    states_[index] = nullptr;
  } else {
    states_[index] = new NDArray(weight.GetShape(), weight.GetContext());
    *states_[index] = 0;
  }
}

// finish implementing Signum
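
The wrapper above dispatches between two imperative operators: signsgd_update (weight and gradient only) when no momentum parameter has been set, and signum_update (weight, gradient, momentum buffer) otherwise. A rough sketch of the same two calls from the Python frontend, using the NDArray operators this commit exposes; the arrays and hyper-parameter values are illustrative only and assume a build that includes this commit:

import mxnet as mx

weight = mx.nd.array([0.5, -1.0, 2.0])
grad = mx.nd.array([0.3, -0.2, 0.0])
mom = mx.nd.zeros_like(weight)   # momentum buffer, the analogue of states_[index]

# no momentum state -> plain sign-SGD step
mx.nd.signsgd_update(weight, grad, out=weight, lr=0.01, wd=0.0)

# momentum state present -> Signum step
mx.nd.signum_update(weight, grad, mom, out=weight,
                    lr=0.01, wd=0.0, momentum=0.9, wd_lh=0.0)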



inline RMSPropOptimizer::RMSPropOptimizer(unsigned begin_num_update)
    : Optimizer(begin_num_update) {
  update_handle_ = op_map()->GetSymbolCreator("rmsprop_update");
67 changes: 64 additions & 3 deletions python/mxnet/optimizer.py
@@ -25,7 +25,8 @@
from .base import py_str
from .ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs)
from .ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update)
mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
signsgd_update, signum_update)
from .ndarray import _internal
from .ndarray import op
from .ndarray import sparse
@@ -534,6 +535,67 @@ def update_multi_precision(self, index, weight, grad, state):
        self._update_impl(index, weight, grad, state,
                          multi_precision=use_multi_precision)

@register
class Signum(Optimizer):
    """The Signum optimizer that takes the sign of gradient or momentum.

    The optimizer updates the weight by:

        rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
        state = momentum * state + (1 - momentum) * rescaled_grad
        weight = (1 - lr * wd_lh) * weight - lr * sign(state)

    See the original paper at: https://jeremybernste.in/projects/amazon/signum.pdf

    For details of the update algorithm see
    :class:`~mxnet.ndarray.signsgd_update` and :class:`~mxnet.ndarray.signum_update`.

    This optimizer accepts the following parameters in addition to those accepted
    by :class:`.Optimizer`.

    Parameters
    ----------
    momentum : float, optional
        The momentum value.
    wd_lh : float, optional
        The amount of decoupled weight decay regularization, see details in the original paper at:
        https://arxiv.org/abs/1711.05101
    """
    def __init__(self, learning_rate=0.01, momentum=0.9, wd_lh=0.0, **kwargs):
        super(Signum, self).__init__(learning_rate=learning_rate, **kwargs)
        self.momentum = momentum
        self.wd_lh = wd_lh

    def create_state(self, index, weight):
        momentum = None
        if self.momentum != 0.0:
            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
        return momentum

    def _update_impl(self, index, weight, grad, state):
        assert(isinstance(weight, NDArray))
        assert(isinstance(grad, NDArray))
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)

        kwargs = {'rescale_grad': self.rescale_grad}
        if self.momentum > 0:
            kwargs['momentum'] = self.momentum
        if self.clip_gradient:
            kwargs['clip_gradient'] = self.clip_gradient
        if self.wd_lh:
            kwargs['wd_lh'] = self.wd_lh

        if state is not None:
            signum_update(weight, grad, state, out=weight,
                          lr=lr, wd=wd, **kwargs)
        else:
            signsgd_update(weight, grad, out=weight,
                           lr=lr, wd=wd, **kwargs)

    def update(self, index, weight, grad, state):
        self._update_impl(index, weight, grad, state)
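
A minimal usage sketch for the Signum class above (illustrative values; it assumes an MXNet build that contains this commit, so the optimizer is importable from mxnet.optimizer):

import mxnet as mx
from mxnet.optimizer import Signum

opt = Signum(learning_rate=0.1, momentum=0.9, wd_lh=0.0)
weight = mx.nd.array([0.5, -1.0, 2.0])
grad = mx.nd.array([0.3, -0.2, 0.0])

state = opt.create_state(0, weight)   # momentum buffer (None when momentum == 0)
opt.update(0, weight, grad, state)    # in place: weight <- (1 - lr*wd_lh)*weight - lr*sign(state)
print(weight.asnumpy())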

@register
class FTML(Optimizer):
@@ -702,8 +764,7 @@ def update(self, index, weight, grad, state):
        if self.clip_gradient is not None:
            grad = clip(grad, -self.clip_gradient, self.clip_gradient)
        weight[:] += - lr/2 * (grad + wd * weight) + normal(0, math.sqrt(lr),
                                                            shape=weight.shape,
                                                            ctx=weight.context)
                                                            weight.shape, weight.context)


@register # pylint: disable=invalid-name
142 changes: 142 additions & 0 deletions src/operator/optimizer_op-inl.h
@@ -66,6 +66,7 @@ struct SGDParam : public dmlc::Parameter<SGDParam> {
}
};


struct SGDKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data,
@@ -228,6 +229,7 @@ struct SGDMomParam : public dmlc::Parameter<SGDMomParam> {
}
};


struct SGDMomKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data,
@@ -1281,6 +1283,146 @@ inline void FtrlUpdateEx(const nnvm::NodeAttrs& attrs,
}
}


// Implementation for signSGD and Signum

struct SignSGDParam : public dmlc::Parameter<SignSGDParam> {
  float lr;
  float wd;
  float rescale_grad;
  float clip_gradient;
  DMLC_DECLARE_PARAMETER(SignSGDParam) {
    DMLC_DECLARE_FIELD(lr)
    .describe("Learning rate");
    DMLC_DECLARE_FIELD(wd)
    .set_default(0.0f)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude of each weight.");
    DMLC_DECLARE_FIELD(rescale_grad)
    .set_default(1.0f)
    .describe("Rescale gradient to grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_gradient)
    .set_default(-1.0f)
    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
              "If clip_gradient <= 0, gradient clipping is turned off. "
              "grad = max(min(grad, clip_gradient), -clip_gradient).");
  }
};


struct SignSGDKernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data,
    const DType* grad_data, const DType param_clip_gradient,
    const DType param_lr, const DType param_wd, const DType param_rescale_grad,
    const OpReqType req) {

    // param_clip_gradient has no effect for SignSGD
    KERNEL_ASSIGN(out_data[i], req,
                  (1.f-param_lr*param_wd)*weight_data[i]
                  - (param_lr)*((grad_data[i] > 0) - (grad_data[i] < 0)));
  }
};

template<typename xpu>
inline void SignSGDUpdate(const nnvm::NodeAttrs& attrs,
                          const OpContext &ctx,
                          const std::vector<TBlob> &inputs,
                          const std::vector<OpReqType> &req,
                          const std::vector<TBlob> &outputs) {
  using namespace mxnet_op;
  const SignSGDParam& param = nnvm::get<SignSGDParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
    Kernel<SignSGDKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, weight.dptr_,
      grad.dptr_, static_cast<DType>(param.clip_gradient),
      static_cast<DType>(param.lr), static_cast<DType>(param.wd),
      static_cast<DType>(param.rescale_grad), req[0]);
  });
}
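
As a cross-check on the kernel above, a small NumPy transcription of the per-element SignSGD rule (an illustrative sketch, not part of the commit; clip_gradient and rescale_grad are accepted by the kernel but do not enter its expression, so they are omitted here):

import numpy as np

def signsgd_step(weight, grad, lr, wd=0.0):
    # mirrors SignSGDKernel: out = (1 - lr*wd) * weight - lr * sign(grad)
    return (1.0 - lr * wd) * weight - lr * np.sign(grad)

w = np.array([0.5, -1.0, 2.0])
g = np.array([0.3, -0.2, 0.0])
print(signsgd_step(w, g, lr=0.01))   # -> [0.49, -0.99, 2.0]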


struct SignumParam : public dmlc::Parameter<SignumParam> {
  float lr;
  float momentum;
  float wd;
  float rescale_grad;
  float clip_gradient;
  float wd_lh;  // the amount of decoupled weight decay, following Loshchilov and Hutter
  DMLC_DECLARE_PARAMETER(SignumParam) {
    DMLC_DECLARE_FIELD(lr)
    .describe("Learning rate");
    DMLC_DECLARE_FIELD(momentum)
    .set_default(0.0f)
    .describe("The decay rate of momentum estimates at each epoch.");
    DMLC_DECLARE_FIELD(wd)
    .set_default(0.0f)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude of each weight.");
    DMLC_DECLARE_FIELD(rescale_grad)
    .set_default(1.0f)
    .describe("Rescale gradient to grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_gradient)
    .set_default(-1.0f)
    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
              "If clip_gradient <= 0, gradient clipping is turned off. "
              "grad = max(min(grad, clip_gradient), -clip_gradient).");
    DMLC_DECLARE_FIELD(wd_lh)
    .set_default(0.0f)
    .describe("The amount of decoupled weight decay that is applied directly to the weight "
              "and does not enter the gradient/momentum calculations.");
  }
};

struct SignumKernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data,
    const DType* grad_data, const DType param_clip_gradient, const DType param_momentum,
    const DType param_lr, const DType param_wd, const DType param_rescale_grad,
    const DType param_wd_lh, const OpReqType req) {
    if (param_clip_gradient >= 0.0f) {
      mom_data[i] = param_momentum*mom_data[i]
                    - (1-param_momentum)*param_wd*weight_data[i]
                    - (1-param_momentum)
                      *mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient);
    } else {
      mom_data[i] = param_momentum*mom_data[i]
                    - (1-param_momentum)*param_wd*weight_data[i]
                    - (1-param_momentum)*param_rescale_grad*grad_data[i];
    }
    KERNEL_ASSIGN(out_data[i], req, (1.f-param_lr*param_wd_lh)*weight_data[i]
                  + (param_lr)*((mom_data[i] > 0) - (mom_data[i] < 0)));
  }
};

template<typename xpu>
inline void SignumUpdate(const nnvm::NodeAttrs& attrs,
                         const OpContext &ctx,
                         const std::vector<TBlob> &inputs,
                         const std::vector<OpReqType> &req,
                         const std::vector<TBlob> &outputs) {
  using namespace mxnet_op;
  SignumParam param = nnvm::get<SignumParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
    Kernel<SignumKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mom.dptr_, weight.dptr_,
      grad.dptr_, static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
      static_cast<DType>(param.lr), static_cast<DType>(param.wd),
      static_cast<DType>(param.rescale_grad), static_cast<DType>(param.wd_lh), req[0]);
  });
}
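
Likewise, a NumPy sketch of the per-element Signum rule implemented by SignumKernel, including the momentum buffer, the optional gradient clipping, and the decoupled wd_lh term (illustrative only; the values are made up):

import numpy as np

def signum_step(weight, mom, grad, lr, momentum, wd=0.0,
                rescale_grad=1.0, clip_gradient=-1.0, wd_lh=0.0):
    # mirrors SignumKernel: refresh the momentum buffer, then step by its sign
    g = rescale_grad * grad
    if clip_gradient >= 0:
        g = np.clip(g, -clip_gradient, clip_gradient)
    mom[:] = momentum * mom - (1 - momentum) * wd * weight - (1 - momentum) * g
    return (1.0 - lr * wd_lh) * weight + lr * np.sign(mom)

w = np.array([0.5, -1.0, 2.0])
m = np.zeros_like(w)
g = np.array([0.3, -0.2, 0.0])
print(signum_step(w, m, g, lr=0.01, momentum=0.9))   # -> [0.49, -0.99, 2.0]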



} // namespace op
} // namespace mxnet
