
reuse sgd updates where convenient
Anirudh Acharya committed Apr 30, 2019
1 parent fc9b8ba commit 3d32e96
Showing 5 changed files with 6 additions and 136 deletions.
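The motivation for the change: Nesterov accelerated gradient without a momentum state degenerates to a plain SGD step, so the dedicated nag_update / mp_nag_update operators duplicated what sgd_update / mp_sgd_update already compute. Below is a minimal NumPy sketch of the shared update formula, taken from the removed NAGKernel; the helper name and test values are illustrative only.

import numpy as np

def no_momentum_step(weight, grad, lr, wd, rescale_grad=1.0, clip_gradient=-1.0):
    # Update applied by the removed nag_update when there is no momentum state;
    # sgd_update computes the same value (possibly factored differently).
    g = rescale_grad * grad
    if clip_gradient >= 0.0:
        g = np.clip(g, -clip_gradient, clip_gradient)
    return weight - lr * (g + wd * weight)

w = np.random.rand(5).astype(np.float32)
g = np.random.rand(5).astype(np.float32)
w_updated = no_momentum_step(w, g, lr=0.1, wd=1e-4)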
10 changes: 5 additions & 5 deletions python/mxnet/optimizer/optimizer.py
@@ -28,9 +28,9 @@
from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply)
from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
- signsgd_update, signum_update, nag_update, nag_mom_update, mp_nag_update,
- mp_nag_mom_update, multi_sgd_update, multi_sgd_mom_update,
- multi_mp_sgd_update, multi_mp_sgd_mom_update)
+ signsgd_update, signum_update, nag_mom_update, mp_nag_mom_update,
+ multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update,
+ multi_mp_sgd_mom_update)
from ..ndarray import sparse
from ..random import normal

@@ -1086,13 +1086,13 @@ def _update_impl(self, index, weight, grad, state, multi_precision=False):
if state is not None:
nag_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs)
else:
- nag_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)
+ sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)
else:
if state[0] is not None:
mp_nag_mom_update(weight, grad, state[0], state[1], out=weight,
lr=lr, wd=wd, **kwargs)
else:
- mp_nag_update(weight, grad, state[1], out=weight,
+ mp_sgd_update(weight, grad, state[1], out=weight,
lr=lr, wd=wd, **kwargs)

def update(self, index, weight, grad, state):
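With this change, the dispatch in NAG._update_impl has the shape sketched below: the momentum paths keep nag_mom_update / mp_nag_mom_update, while the momentum-free paths now reuse sgd_update / mp_sgd_update. This is a condensed, illustrative sketch rather than the verbatim file; lr, wd and kwargs are assumed to be prepared by the surrounding method as before.

# Condensed sketch of the branch structure after this commit (illustrative).
if not multi_precision:
    if state is not None:        # momentum buffer present
        nag_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs)
    else:                        # no momentum: a plain SGD step is identical
        sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)
else:
    if state[0] is not None:     # state = (momentum, fp32 master weights)
        mp_nag_mom_update(weight, grad, state[0], state[1], out=weight,
                          lr=lr, wd=wd, **kwargs)
    else:                        # no momentum: multi-precision SGD step
        mp_sgd_update(weight, grad, state[1], out=weight, lr=lr, wd=wd, **kwargs)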
87 changes: 0 additions & 87 deletions src/operator/optimizer_op-inl.h
@@ -1029,48 +1029,6 @@ struct NAGParam : public dmlc::Parameter<NAGParam> {
}
};

struct NAGKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data,
const DType* weight_data, const DType* grad_data,
const DType param_clip_gradient, const DType param_lr,
const DType param_wd, const DType param_rescale_grad,
const OpReqType req) {
if (param_clip_gradient >= 0.0f) {
KERNEL_ASSIGN(out_data[i], req,
weight_data[i]
- param_lr * (mshadow_op::clip::Map(param_rescale_grad*grad_data[i],
param_clip_gradient)
+ param_wd*weight_data[i]));
} else {
KERNEL_ASSIGN(out_data[i], req,
weight_data[i]
- param_lr * (param_rescale_grad*grad_data[i]
+ (param_wd*weight_data[i])));
}
}
};

template<typename xpu>
inline void NAGUpdate(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
const NAGParam& param = nnvm::get<NAGParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
Kernel<NAGKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_,
weight.dptr_, grad.dptr_, static_cast<DType>(param.clip_gradient),
static_cast<DType>(param.lr), static_cast<DType>(param.wd),
static_cast<DType>(param.rescale_grad), req[0]);
});
}

struct NAGMomParam : public dmlc::Parameter<NAGMomParam> {
float lr;
float momentum;
@@ -1150,51 +1108,6 @@ inline void NAGMomUpdate(const nnvm::NodeAttrs& attrs,
});
}

struct MP_NAGKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data,
const DType* weight_data, const DType* grad_data,
float* weight32, const float param_clip_gradient,
const float param_lr, const float param_wd,
const float param_rescale_grad,
const OpReqType req) {
if (param_clip_gradient >= 0.0f) {
float w = weight32[i];
w = w - param_lr * (mshadow_op::clip::Map(param_rescale_grad
*static_cast<float>(grad_data[i]), param_clip_gradient)
+ param_wd*w);
weight32[i] = w;
KERNEL_ASSIGN(out_data[i], req, (DType)w);
} else {
float w = weight32[i];
w = w - param_lr * (param_rescale_grad
*static_cast<float>(grad_data[i]) + (param_wd*w));
weight32[i] = w;
KERNEL_ASSIGN(out_data[i], req, (DType)w);
}
}
};

template<typename xpu>
inline void MP_NAGUpdate(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
const NAGParam& param = nnvm::get<NAGParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, float> weight32 = inputs[2].FlatTo2D<xpu, float>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
Kernel<MP_NAGKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_,
weight.dptr_, grad.dptr_, weight32.dptr_, param.clip_gradient,
param.lr, param.wd, param.rescale_grad, req[0]);
});
}

struct MP_NAGMomKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data,
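The removed NAGKernel and MP_NAGKernel implemented the same arithmetic, w_new = w - lr * (clip(rescale_grad * g, clip_gradient) + wd * w); MP_NAGKernel simply performed it on an fp32 master copy of the weights and cast the result back to the output dtype, which is what the existing sgd_update / mp_sgd_update kernels already provide. For reference, a NumPy sketch of the multi-precision variant follows (the function name and test values are illustrative).

import numpy as np

def mp_step(weight_lp, grad_lp, weight_fp32, lr, wd, rescale_grad=1.0, clip_gradient=-1.0):
    # Mirrors the removed MP_NAGKernel: compute on the fp32 master weights,
    # then cast the result back to the low-precision dtype.
    g = rescale_grad * grad_lp.astype(np.float32)
    if clip_gradient >= 0.0:
        g = np.clip(g, -clip_gradient, clip_gradient)
    w32 = weight_fp32 - lr * (g + wd * weight_fp32)
    return w32.astype(weight_lp.dtype), w32

w16 = np.random.rand(4).astype(np.float16)
g16 = np.random.rand(4).astype(np.float16)
w16, w32 = mp_step(w16, g16, w16.astype(np.float32), lr=0.1, wd=1e-4)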
37 changes: 0 additions & 37 deletions src/operator/optimizer_op.cc
@@ -707,24 +707,6 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a
.add_arguments(AdamParam::__FIELDS__());


NNVM_REGISTER_OP(nag_update)
.describe(R"code(Update function for Nesterov Accelerated Gradient( NAG) optimizer.
It updates the weights using the following formula,
weight = weight - (lr * (grad + wd * weight))
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NAGParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FCompute>("FCompute<cpu>", NAGUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_arguments(NAGParam::__FIELDS__());


NNVM_REGISTER_OP(nag_mom_update)
.describe(R"code(Update function for Nesterov Accelerated Gradient( NAG) optimizer.
It updates the weights using the following formula,
@@ -756,25 +738,6 @@ Where
.add_arguments(NAGMomParam::__FIELDS__());


NNVM_REGISTER_OP(mp_nag_update)
.describe(R"code(Update function for multi-precision Nesterov Accelerated Gradient( NAG) optimizer.
)code" ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NAGParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", MP_InferType<2, 1, 3>)
.set_attr<FCompute>("FCompute<cpu>", MP_NAGUpdate<cpu>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2};
})
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "gradient")
.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
.add_arguments(NAGParam::__FIELDS__());


NNVM_REGISTER_OP(mp_nag_mom_update)
.describe(R"code(Update function for multi-precision Nesterov Accelerated Gradient( NAG) optimizer.
)code" ADD_FILELINE)
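At the imperative NDArray frontend, the removed operators are covered by the SGD ones with the same call pattern; the momentum operator nag_mom_update registered above is untouched. A hedged example of the replacement (learning-rate and weight-decay values are arbitrary):

import mxnet as mx

weight = mx.nd.random.uniform(shape=(4,))
grad = mx.nd.random.uniform(shape=(4,))
mom = mx.nd.zeros_like(weight)

# Removed in this commit:
#   mx.nd.nag_update(weight, grad, out=weight, lr=0.1, wd=1e-4)
# Equivalent existing operator:
mx.nd.sgd_update(weight, grad, out=weight, lr=0.1, wd=1e-4)

# The momentum variant keeps working as before:
mx.nd.nag_mom_update(weight, grad, mom, out=weight, lr=0.1, momentum=0.9, wd=1e-4)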
6 changes: 0 additions & 6 deletions src/operator/optimizer_op.cu
@@ -251,15 +251,9 @@ NNVM_REGISTER_OP(multi_mp_sgd_update)
NNVM_REGISTER_OP(multi_mp_sgd_mom_update)
.set_attr<FCompute>("FCompute<gpu>", MultiSGDMomUpdate<gpu, single_precision, 4>);

NNVM_REGISTER_OP(nag_update)
.set_attr<FCompute>("FCompute<gpu>", NAGUpdate<gpu>);

NNVM_REGISTER_OP(nag_mom_update)
.set_attr<FCompute>("FCompute<gpu>", NAGMomUpdate<gpu>);

NNVM_REGISTER_OP(mp_nag_update)
.set_attr<FCompute>("FCompute<gpu>", MP_NAGUpdate<gpu>);

NNVM_REGISTER_OP(mp_nag_mom_update)
.set_attr<FCompute>("FCompute<gpu>", MP_NAGMomUpdate<gpu>);

2 changes: 1 addition & 1 deletion tests/python/unittest/test_optimizer.py
@@ -425,7 +425,7 @@ def test_nag():
if (dtype == np.float16 and ('multi_precision' not in kwarg or
not kwarg['multi_precision'])):
continue
- compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
+ compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype, rtol=1e-3, atol=1e-4)

#SGLD
class PySGLD(mx.optimizer.Optimizer):
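The single test change loosens the comparison tolerances in test_nag. compare_optimizer updates the same weights with the Python reference optimizer (opt1) and the backend-accelerated one (opt2) and asserts the results agree; the explicit rtol/atol presumably leave headroom for fp16 rounding on the multi-precision path now served by the SGD kernels. Conceptually, the final check amounts to a comparison like the simplified stand-in below (not the actual test helper).

import numpy as np

def assert_updates_agree(w_reference, w_backend, rtol=1e-3, atol=1e-4):
    # Simplified stand-in for compare_optimizer's final comparison: the Python
    # reference update and the operator-based update must match within tolerance.
    np.testing.assert_allclose(w_reference, w_backend, rtol=rtol, atol=atol)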
