
Commit fb7ba2a: nag_mp
Anirudh Acharya committed Mar 29, 2019
1 parent 645c778
Showing 4 changed files with 380 additions and 18 deletions.
64 changes: 46 additions & 18 deletions python/mxnet/optimizer/optimizer.py
@@ -28,9 +28,9 @@
from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply)
from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
-                       signsgd_update, signum_update,
-                       multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update,
-                       multi_mp_sgd_mom_update)
+                       signsgd_update, signum_update, nag_update, nag_mom_update, mp_nag_update,
+                       mp_nag_mom_update, multi_sgd_update, multi_sgd_mom_update,
+                       multi_mp_sgd_update, multi_mp_sgd_mom_update)
from ..ndarray import sparse
from ..random import normal

@@ -1029,7 +1029,7 @@ def update(self, index, weight, grad, state):

@register
class NAG(Optimizer):
"""Nesterov accelerated SGD.
"""Nesterov accelerated gradient.
This optimizer updates each weight by::
@@ -1051,33 +1051,61 @@ def __init__(self, momentum=0.0, **kwargs):
super(NAG, self).__init__(**kwargs)
self.momentum = momentum

def create_state_multi_precision(self, index, weight):
weight_master_copy = None
if self.multi_precision and weight.dtype == numpy.float16:
weight_master_copy = weight.astype(numpy.float32)
return (self.create_state(index, weight_master_copy), weight_master_copy)
if weight.dtype == numpy.float16 and not self.multi_precision:
warnings.warn("Accumulating with float16 in optimizer can lead to "
"poor accuracy or slow convergence. "
"Consider using multi_precision=True option of the "
"NAG optimizer")
return self.create_state(index, weight)
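Not part of the commit: a minimal sketch of the state layout this method produces for a float16 weight, assuming an MXNet build that includes these changes. Element 0 is the float32 momentum buffer (or None when momentum is 0) and element 1 is the float32 master copy of the weight, which is exactly how the multi-precision branch of _update_impl below indexes it.

import numpy as np
import mxnet as mx
from mxnet import nd

opt = mx.optimizer.NAG(learning_rate=0.1, momentum=0.9, multi_precision=True)
w16 = nd.ones((4,), dtype=np.float16)               # float16 weight triggers the fp32 master copy
mom32, master32 = opt.create_state_multi_precision(0, w16)
assert mom32.dtype == np.float32 and master32.dtype == np.float32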

def create_state(self, index, weight):
momentum = None
if self.momentum != 0.0:
momentum = zeros(weight.shape, weight.context, dtype=weight.dtype)
return momentum

-    def update(self, index, weight, grad, state):
+    def _update_impl(self, index, weight, grad, state, multi_precision=False):
assert(isinstance(weight, NDArray))
assert(isinstance(grad, NDArray))
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)

-        grad = grad * self.rescale_grad
-        if self.clip_gradient is not None:
-            grad = clip(grad, -self.clip_gradient, self.clip_gradient)
+        kwargs = {'rescale_grad': self.rescale_grad}
+        if self.momentum > 0:
+            kwargs['momentum'] = self.momentum
+        if self.clip_gradient:
+            kwargs['clip_gradient'] = self.clip_gradient

-        if state is not None:
-            mom = state
-            mom[:] *= self.momentum
-            mom[:] += grad
-            mom[:] += wd * weight
-            grad[:] += self.momentum * mom
-            weight[:] -= lr * grad
+        if not multi_precision:
+            if state is not None:
+                nag_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs)
+            else:
+                nag_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)
+        else:
+            if state[0] is not None:
+                mp_nag_mom_update(weight, grad, state[0], state[1], out=weight,
+                                  lr=lr, wd=wd, **kwargs)
+            else:
+                mp_nag_update(weight, grad, state[1], out=weight,
+                              lr=lr, wd=wd, **kwargs)

def update(self, index, weight, grad, state):
self._update_impl(index, weight, grad, state, multi_precision=False)

def update_multi_precision(self, index, weight, grad, state):
if not isinstance(index, (tuple, list)):
use_multi_precision = self.multi_precision and weight.dtype == numpy.float16
else:
-            assert self.momentum == 0.0
-            weight[:] += -lr * (grad + wd * weight)
+            use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16
self._update_impl(index, weight, grad, state,
multi_precision=use_multi_precision)
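Not part of the diff: a minimal usage sketch of the updated Python path, assuming an MXNet build that includes this commit. With a non-zero momentum the call below dispatches to the fused nag_mom_update operator added in optimizer_op-inl.h.

import mxnet as mx
from mxnet import nd

opt = mx.optimizer.NAG(learning_rate=0.1, momentum=0.9, wd=1e-4)
updater = mx.optimizer.get_updater(opt)

weight = nd.ones((3,))
grad = nd.full((3,), 0.5)
updater(0, grad, weight)       # in-place update of `weight` via nag_mom_update
print(weight.asnumpy())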


@register
class SGLD(Optimizer):
@@ -1380,7 +1408,7 @@ def update(self, index, weight, grad, state):
# preprocess grad
grad *= self.rescale_grad
if self.clip_gradient is not None:
-            grad = clip(grad, -self.clip_gradient, self.clip_gradient)
+            grad = clip(grad, - self.clip_gradient, self.clip_gradient)

# accumulated g and delta initialization
acc_g, acc_delta = state
240 changes: 240 additions & 0 deletions src/operator/optimizer_op-inl.h
@@ -140,6 +140,7 @@ struct MultiSGDMomParam : public dmlc::Parameter<MultiSGDMomParam> {
}
};


template<typename ParamType, int input_stride>
inline bool MultiSGDShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
@@ -1003,6 +1004,245 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
}


struct NAGParam : public dmlc::Parameter<NAGParam> {
float lr;
float wd;
float rescale_grad;
float clip_gradient;
DMLC_DECLARE_PARAMETER(NAGParam) {
DMLC_DECLARE_FIELD(lr)
.describe("Learning rate");
DMLC_DECLARE_FIELD(wd)
.set_default(0.0f)
.describe("Weight decay augments the objective function with a "
"regularization term that penalizes large weights. "
"The penalty scales with the square of the magnitude of each weight.");
DMLC_DECLARE_FIELD(rescale_grad)
.set_default(1.0f)
.describe("Rescale gradient to grad = rescale_grad*grad.");
DMLC_DECLARE_FIELD(clip_gradient)
.set_default(-1.0f)
.describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
"If clip_gradient <= 0, gradient clipping is turned off. "
"grad = max(min(grad, clip_gradient), -clip_gradient).");
}
};

struct NAGKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data,
const DType* grad_data, const DType param_clip_gradient,
const DType param_lr, const DType param_wd, const DType param_rescale_grad,
const OpReqType req) {
if (param_clip_gradient >= 0.0f) {
KERNEL_ASSIGN(out_data[i], req,
weight_data[i]
- param_lr * (mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient)
+ param_wd*weight_data[i]));
} else {
KERNEL_ASSIGN(out_data[i], req,
weight_data[i]
- param_lr * (param_rescale_grad*grad_data[i] + (param_wd*weight_data[i])));
}
}
};
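For readers who prefer NumPy to kernel code, a rough reference of what NAGKernel computes per element; this is a sketch, not part of the commit, and nag_update_ref is an illustrative name.

import numpy as np

def nag_update_ref(weight, grad, lr, wd, rescale_grad=1.0, clip_gradient=-1.0):
    g = rescale_grad * grad
    if clip_gradient >= 0:                      # clip_gradient < 0 disables clipping
        g = np.clip(g, -clip_gradient, clip_gradient)
    return weight - lr * (g + wd * weight)      # plain SGD step when momentum == 0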

template<typename xpu>
inline void NAGUpdate(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
const NAGParam& param = nnvm::get<NAGParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
Kernel<NAGKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, weight.dptr_,
grad.dptr_, static_cast<DType>(param.clip_gradient),
static_cast<DType>(param.lr), static_cast<DType>(param.wd),
static_cast<DType>(param.rescale_grad), req[0]);
});
}

struct NAGMomParam : public dmlc::Parameter<NAGMomParam> {
float lr;
float momentum;
float wd;
float rescale_grad;
float clip_gradient;
DMLC_DECLARE_PARAMETER(NAGMomParam) {
DMLC_DECLARE_FIELD(lr)
.describe("Learning rate");
DMLC_DECLARE_FIELD(momentum)
.set_default(0.0f)
.describe("The decay rate of momentum estimates at each epoch.");
DMLC_DECLARE_FIELD(wd)
.set_default(0.0f)
.describe("Weight decay augments the objective function with a "
"regularization term that penalizes large weights. "
"The penalty scales with the square of the magnitude of each weight.");
DMLC_DECLARE_FIELD(rescale_grad)
.set_default(1.0f)
.describe("Rescale gradient to grad = rescale_grad*grad.");
DMLC_DECLARE_FIELD(clip_gradient)
.set_default(-1.0f)
.describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
"If clip_gradient <= 0, gradient clipping is turned off. "
"grad = max(min(grad, clip_gradient), -clip_gradient).");
}
};

struct NAGMomKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data,
const DType* grad_data, const DType param_clip_gradient, const DType param_momentum,
const DType param_lr, const DType param_wd, const DType param_rescale_grad,
const OpReqType req) {
if (param_clip_gradient >= 0.0f) {
mom_data[i] = param_momentum*mom_data[i]
+ mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient)
+ (param_wd*weight_data[i]);
KERNEL_ASSIGN(out_data[i], req, weight_data[i] - param_lr*(param_momentum*mom_data[i]
+ mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient)));
} else {
mom_data[i] = param_momentum*mom_data[i]
+ param_rescale_grad*grad_data[i]
+ (param_wd*weight_data[i]);
KERNEL_ASSIGN(out_data[i], req, weight_data[i] - param_lr*(param_momentum*mom_data[i]
+ param_rescale_grad*grad_data[i]));
}
}
};
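The momentum path in NAGMomKernel first folds the rescaled (and optionally clipped) gradient plus weight decay into the momentum buffer, then takes a Nesterov look-ahead step using the fresh momentum. A NumPy sketch of the same per-element math, with illustrative names, not part of the commit:

import numpy as np

def nag_mom_update_ref(weight, grad, mom, lr, momentum, wd,
                       rescale_grad=1.0, clip_gradient=-1.0):
    g = rescale_grad * grad
    if clip_gradient >= 0:
        g = np.clip(g, -clip_gradient, clip_gradient)
    mom = momentum * mom + g + wd * weight       # momentum state update
    weight = weight - lr * (momentum * mom + g)  # look-ahead: reuse the updated momentum
    return weight, mom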

template<typename xpu>
inline void NAGMomUpdate(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
NAGMomParam param = nnvm::get<NAGMomParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
Kernel<NAGMomKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mom.dptr_, weight.dptr_,
grad.dptr_, static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
static_cast<DType>(param.lr), static_cast<DType>(param.wd),
static_cast<DType>(param.rescale_grad), req[0]);
});
}

template<int n_in, int n_out, int total_in>
inline bool MP_NAG_InferType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
for (int i = n_in; i < total_in; ++i) {
TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat32);
}
return ElemwiseAttr<int, type_is_none, type_assign, true, type_string, n_in, n_out>(
attrs, in_attrs, out_attrs, -1);
}

struct MP_NAGKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data,
const DType* grad_data, float* weight32, const float param_clip_gradient,
const float param_lr, const float param_wd, const float param_rescale_grad,
const OpReqType req) {
if (param_clip_gradient >= 0.0f) {
float w = weight32[i];
w = w - param_lr * (mshadow_op::clip::Map(param_rescale_grad*static_cast<float>(grad_data[i]), param_clip_gradient)
+ param_wd*w);
weight32[i] = w;
KERNEL_ASSIGN(out_data[i], req, (DType)w);
} else {
float w = weight32[i];
w = w - param_lr * (param_rescale_grad*static_cast<float>(grad_data[i]) + (param_wd*w));
weight32[i] = w;
KERNEL_ASSIGN(out_data[i], req, (DType)w);
}
}
};

template<typename xpu>
inline void MP_NAGUpdate(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
const NAGParam& param = nnvm::get<NAGParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, float> weight32 = inputs[2].FlatTo2D<xpu, float>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
Kernel<MP_NAGKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, weight.dptr_,
grad.dptr_, weight32.dptr_, param.clip_gradient,
param.lr, param.wd,
param.rescale_grad, req[0]);
});
}

struct MP_NAGMomKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, float* mom_data,
const DType* weight_data, const DType* grad_data, float* weight32,
const float param_clip_gradient, const float param_momentum, const float param_lr,
const float param_wd, const float param_rescale_grad, const OpReqType req) {
float w = weight32[i];
float mom = mom_data[i];
if (param_clip_gradient >= 0.0f) {
mom_data[i] = param_momentum*mom_data[i]
+ mshadow_op::clip::Map(param_rescale_grad*static_cast<float>(grad_data[i]), param_clip_gradient)
+ (param_wd*w);
w = w - param_lr*(param_momentum*mom_data[i]
+ mshadow_op::clip::Map(param_rescale_grad*static_cast<float>(grad_data[i]),
param_clip_gradient));
weight32[i] = w;
KERNEL_ASSIGN(out_data[i], req, w);
} else {
mom_data[i] = param_momentum*mom_data[i]
+ param_rescale_grad*static_cast<float>(grad_data[i])
+ (param_wd*w);
w = w - param_lr*(param_momentum*mom_data[i] + param_rescale_grad*static_cast<float>(grad_data[i]));
weight32[i] = w;
      KERNEL_ASSIGN(out_data[i], req, w);
}
}
};

template<typename xpu>
inline void MP_NAGMomUpdate(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
NAGMomParam param = nnvm::get<NAGMomParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, float> mom = inputs[2].FlatTo2D<xpu, float>(s);
Tensor<xpu, 2, float> weight32 = inputs[3].FlatTo2D<xpu, float>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
Kernel<MP_NAGMomKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mom.dptr_,
weight.dptr_, grad.dptr_, weight32.dptr_, param.clip_gradient, param.momentum,
param.lr, param.wd, param.rescale_grad, req[0]);
});
}
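The MP_* variants mirror the kernels above but keep a float32 master copy of a float16 weight, accumulating in float32 and casting only the output back to float16, which is what the float16 accumulation warning in the Python optimizer refers to. A NumPy sketch under those assumptions, with illustrative names, not part of the commit:

import numpy as np

def mp_nag_mom_update_ref(weight16, grad16, mom32, weight32, lr, momentum, wd,
                          rescale_grad=1.0, clip_gradient=-1.0):
    # weight16 only mirrors the op signature; the arithmetic uses the float32 master copy
    g = rescale_grad * grad16.astype(np.float32)     # gradient promoted to float32
    if clip_gradient >= 0:
        g = np.clip(g, -clip_gradient, clip_gradient)
    mom32[:] = momentum * mom32 + g + wd * weight32  # state kept in float32
    weight32[:] -= lr * (momentum * mom32 + g)       # master weights updated in float32
    return weight32.astype(np.float16)               # float16 result written back to the weight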


struct FTMLParam : public dmlc::Parameter<FTMLParam> {
float lr;
float beta1;