From 9d02dd33cdb9af53707b1ab40b33978c9b8a8ba7 Mon Sep 17 00:00:00 2001 From: Anirudh Acharya Date: Wed, 13 Mar 2019 15:44:21 -0700 Subject: [PATCH] nag_mp --- python/mxnet/optimizer/optimizer.py | 64 ++++-- src/operator/optimizer_op-inl.h | 250 +++++++++++++++++++++++- src/operator/optimizer_op.cc | 86 +++++++- src/operator/optimizer_op.cu | 12 ++ tests/python/unittest/test_optimizer.py | 28 +-- 5 files changed, 401 insertions(+), 39 deletions(-) diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py index 2e7fe86c5af9..e5c5ff026112 100644 --- a/python/mxnet/optimizer/optimizer.py +++ b/python/mxnet/optimizer/optimizer.py @@ -28,9 +28,9 @@ from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply) from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update, mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update, - signsgd_update, signum_update, - multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update, - multi_mp_sgd_mom_update) + signsgd_update, signum_update, nag_update, nag_mom_update, mp_nag_update, + mp_nag_mom_update, multi_sgd_update, multi_sgd_mom_update, + multi_mp_sgd_update, multi_mp_sgd_mom_update) from ..ndarray import sparse from ..random import normal @@ -1029,7 +1029,7 @@ def update(self, index, weight, grad, state): @register class NAG(Optimizer): - """Nesterov accelerated SGD. + """Nesterov accelerated gradient. This optimizer updates each weight by:: @@ -1051,33 +1051,61 @@ def __init__(self, momentum=0.0, **kwargs): super(NAG, self).__init__(**kwargs) self.momentum = momentum + def create_state_multi_precision(self, index, weight): + weight_master_copy = None + if self.multi_precision and weight.dtype == numpy.float16: + weight_master_copy = weight.astype(numpy.float32) + return (self.create_state(index, weight_master_copy), weight_master_copy) + if weight.dtype == numpy.float16 and not self.multi_precision: + warnings.warn("Accumulating with float16 in optimizer can lead to " + "poor accuracy or slow convergence. " + "Consider using multi_precision=True option of the " + "NAG optimizer") + return self.create_state(index, weight) + def create_state(self, index, weight): momentum = None if self.momentum != 0.0: momentum = zeros(weight.shape, weight.context, dtype=weight.dtype) return momentum - def update(self, index, weight, grad, state): + def _update_impl(self, index, weight, grad, state, multi_precision=False): assert(isinstance(weight, NDArray)) assert(isinstance(grad, NDArray)) self._update_count(index) lr = self._get_lr(index) wd = self._get_wd(index) - grad = grad * self.rescale_grad - if self.clip_gradient is not None: - grad = clip(grad, -self.clip_gradient, self.clip_gradient) + kwargs = {'rescale_grad': self.rescale_grad} + if self.momentum > 0: + kwargs['momentum'] = self.momentum + if self.clip_gradient: + kwargs['clip_gradient'] = self.clip_gradient - if state is not None: - mom = state - mom[:] *= self.momentum - mom[:] += grad - mom[:] += wd * weight - grad[:] += self.momentum * mom - weight[:] -= lr * grad + if not multi_precision: + if state is not None: + nag_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs) + else: + nag_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs) + else: + if state[0] is not None: + mp_nag_mom_update(weight, grad, state[0], state[1], out=weight, + lr=lr, wd=wd, **kwargs) + else: + mp_nag_update(weight, grad, state[1], out=weight, + lr=lr, wd=wd, **kwargs) + + def update(self, index, weight, grad, state): + self._update_impl(index, weight, grad, state, multi_precision=False) + + def update_multi_precision(self, index, weight, grad, state): + if not isinstance(index, (tuple, list)): + use_multi_precision = self.multi_precision and weight.dtype == numpy.float16 else: - assert self.momentum == 0.0 - weight[:] += -lr * (grad + wd * weight) + use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16 + self._update_impl(index, weight, grad, state, + multi_precision=use_multi_precision) + @register class SGLD(Optimizer): @@ -1380,7 +1408,7 @@ def update(self, index, weight, grad, state): # preprocess grad grad *= self.rescale_grad if self.clip_gradient is not None: - grad = clip(grad, -self.clip_gradient, self.clip_gradient) + grad = clip(grad, - self.clip_gradient, self.clip_gradient) # accumulated g and delta initlization acc_g, acc_delta = state diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 49eb96b9f8b2..523c5c64339d 100644 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -140,6 +140,7 @@ struct MultiSGDMomParam : public dmlc::Parameter { } }; + template inline bool MultiSGDShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_attrs, @@ -639,7 +640,7 @@ inline void SGDMomUpdate(const nnvm::NodeAttrs& attrs, } template -inline bool MP_SGD_InferType(const nnvm::NodeAttrs& attrs, +inline bool MP_InferType(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), static_cast(total_in)) << " in operator " << attrs.name; @@ -1003,6 +1004,253 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs, } +struct NAGParam : public dmlc::Parameter { + float lr; + float wd; + float rescale_grad; + float clip_gradient; + DMLC_DECLARE_PARAMETER(NAGParam) { + DMLC_DECLARE_FIELD(lr) + .describe("Learning rate"); + DMLC_DECLARE_FIELD(wd) + .set_default(0.0f) + .describe("Weight decay augments the objective function with a " + "regularization term that penalizes large weights. " + "The penalty scales with the square of the magnitude " + "of each weight."); + DMLC_DECLARE_FIELD(rescale_grad) + .set_default(1.0f) + .describe("Rescale gradient to grad = rescale_grad*grad."); + DMLC_DECLARE_FIELD(clip_gradient) + .set_default(-1.0f) + .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. " + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + } +}; + +struct NAGKernel { + template + MSHADOW_XINLINE static void Map(int i, DType* out_data, + const DType* weight_data, const DType* grad_data, + const DType param_clip_gradient, const DType param_lr, + const DType param_wd, const DType param_rescale_grad, + const OpReqType req) { + if (param_clip_gradient >= 0.0f) { + KERNEL_ASSIGN(out_data[i], req, + weight_data[i] + - param_lr * (mshadow_op::clip::Map(param_rescale_grad*grad_data[i], + param_clip_gradient) + + param_wd*weight_data[i])); + } else { + KERNEL_ASSIGN(out_data[i], req, + weight_data[i] + - param_lr * (param_rescale_grad*grad_data[i] + + (param_wd*weight_data[i]))); + } + } +}; + +template +inline void NAGUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + const NAGParam& param = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + Tensor weight = inputs[0].FlatTo2D(s); + Tensor grad = inputs[1].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); + Kernel::Launch(s, weight.shape_.Size(), out.dptr_, + weight.dptr_, grad.dptr_, static_cast(param.clip_gradient), + static_cast(param.lr), static_cast(param.wd), + static_cast(param.rescale_grad), req[0]); + }); +} + +struct NAGMomParam : public dmlc::Parameter { + float lr; + float momentum; + float wd; + float rescale_grad; + float clip_gradient; + DMLC_DECLARE_PARAMETER(NAGMomParam) { + DMLC_DECLARE_FIELD(lr) + .describe("Learning rate"); + DMLC_DECLARE_FIELD(momentum) + .set_default(0.0f) + .describe("The decay rate of momentum estimates at each epoch."); + DMLC_DECLARE_FIELD(wd) + .set_default(0.0f) + .describe("Weight decay augments the objective function with a " + "regularization term that penalizes large weights. " + "The penalty scales with the square of the magnitude " + "of each weight."); + DMLC_DECLARE_FIELD(rescale_grad) + .set_default(1.0f) + .describe("Rescale gradient to grad = rescale_grad*grad."); + DMLC_DECLARE_FIELD(clip_gradient) + .set_default(-1.0f) + .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. " + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + } +}; + +struct NAGMomKernel { + template + MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, + const DType* weight_data, const DType* grad_data, + const DType param_clip_gradient, const DType param_momentum, + const DType param_lr, const DType param_wd, + const DType param_rescale_grad, const OpReqType req) { + if (param_clip_gradient >= 0.0f) { + mom_data[i] = param_momentum*mom_data[i] + + mshadow_op::clip::Map(param_rescale_grad*grad_data[i], + param_clip_gradient) + + (param_wd*weight_data[i]); + KERNEL_ASSIGN(out_data[i], req, weight_data[i] + - param_lr*(param_momentum*mom_data[i] + + mshadow_op::clip::Map(param_rescale_grad*grad_data[i], + param_clip_gradient))); + } else { + mom_data[i] = param_momentum*mom_data[i] + + param_rescale_grad*grad_data[i] + + (param_wd*weight_data[i]); + KERNEL_ASSIGN(out_data[i], req, weight_data[i] + - param_lr*(param_momentum*mom_data[i] + + param_rescale_grad*grad_data[i])); + } + } +}; + +template +inline void NAGMomUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + NAGMomParam param = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + Tensor weight = inputs[0].FlatTo2D(s); + Tensor grad = inputs[1].FlatTo2D(s); + Tensor mom = inputs[2].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); + Kernel::Launch(s, weight.shape_.Size(), out.dptr_, + mom.dptr_, weight.dptr_, grad.dptr_, + static_cast(param.clip_gradient), + static_cast(param.momentum), static_cast(param.lr), + static_cast(param.wd), static_cast(param.rescale_grad), + req[0]); + }); +} + +struct MP_NAGKernel { + template + MSHADOW_XINLINE static void Map(int i, DType* out_data, + const DType* weight_data, const DType* grad_data, + float* weight32, const float param_clip_gradient, + const float param_lr, const float param_wd, + const float param_rescale_grad, + const OpReqType req) { + if (param_clip_gradient >= 0.0f) { + float w = weight32[i]; + w = w - param_lr * (mshadow_op::clip::Map(param_rescale_grad + *static_cast(grad_data[i]), param_clip_gradient) + + param_wd*w); + weight32[i] = w; + KERNEL_ASSIGN(out_data[i], req, (DType)w); + } else { + float w = weight32[i]; + w = w - param_lr * (param_rescale_grad + *static_cast(grad_data[i]) + (param_wd*w)); + weight32[i] = w; + KERNEL_ASSIGN(out_data[i], req, (DType)w); + } + } +}; + +template +inline void MP_NAGUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + const NAGParam& param = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + Tensor weight = inputs[0].FlatTo2D(s); + Tensor grad = inputs[1].FlatTo2D(s); + Tensor weight32 = inputs[2].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); + Kernel::Launch(s, weight.shape_.Size(), out.dptr_, + weight.dptr_, grad.dptr_, weight32.dptr_, param.clip_gradient, + param.lr, param.wd, param.rescale_grad, req[0]); + }); +} + +struct MP_NAGMomKernel { + template + MSHADOW_XINLINE static void Map(int i, DType* out_data, + float* mom_data, const DType* weight_data, + const DType* grad_data, float* weight32, + const float param_clip_gradient, + const float param_momentum, const float param_lr, + const float param_wd, const float param_rescale_grad, + const OpReqType req) { + float w = weight32[i]; + if (param_clip_gradient >= 0.0f) { + mom_data[i] = param_momentum*mom_data[i] + + mshadow_op::clip::Map(param_rescale_grad + *static_cast(grad_data[i]), param_clip_gradient) + + (param_wd*w); + w = w - param_lr*(param_momentum*mom_data[i] + + mshadow_op::clip::Map(param_rescale_grad + *static_cast(grad_data[i]), + param_clip_gradient)); + weight32[i] = w; + KERNEL_ASSIGN(out_data[i], req, w); + } else { + mom_data[i] = param_momentum*mom_data[i] + + param_rescale_grad*static_cast(grad_data[i]) + + (param_wd*w); + w = w - param_lr*(param_momentum*mom_data[i] + + param_rescale_grad*static_cast(grad_data[i])); + weight32[i] = w; + KERNEL_ASSIGN(out_data[i], req, w); + } + } +}; + +template +inline void MP_NAGMomUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + NAGMomParam param = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + Tensor weight = inputs[0].FlatTo2D(s); + Tensor grad = inputs[1].FlatTo2D(s); + Tensor mom = inputs[2].FlatTo2D(s); + Tensor weight32 = inputs[3].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); + Kernel::Launch(s, weight.shape_.Size(), out.dptr_, + mom.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_, + param.clip_gradient, param.momentum, param.lr, param.wd, + param.rescale_grad, req[0]); + }); +} + + struct FTMLParam : public dmlc::Parameter { float lr; float beta1; diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index 367b91b2646c..859fa58638b0 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -35,6 +35,8 @@ DMLC_REGISTER_PARAMETER(MultiSGDParam); DMLC_REGISTER_PARAMETER(MultiSGDMomParam); DMLC_REGISTER_PARAMETER(FTMLParam); DMLC_REGISTER_PARAMETER(AdamParam); +DMLC_REGISTER_PARAMETER(NAGParam); +DMLC_REGISTER_PARAMETER(NAGMomParam); DMLC_REGISTER_PARAMETER(RMSPropParam); DMLC_REGISTER_PARAMETER(RMSPropAlexParam); DMLC_REGISTER_PARAMETER(FtrlParam); @@ -590,7 +592,7 @@ NNVM_REGISTER_OP(mp_sgd_update) .set_num_outputs(1) .set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<3, 1>) -.set_attr("FInferType", MP_SGD_InferType<2, 1, 3>) +.set_attr("FInferType", MP_InferType<2, 1, 3>) .set_attr("FCompute", MP_SGDUpdate) .set_attr("FMutateInputs", [](const nnvm::NodeAttrs& attrs) { @@ -607,7 +609,7 @@ NNVM_REGISTER_OP(mp_sgd_mom_update) .set_num_outputs(1) .set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<4, 1>) -.set_attr("FInferType", MP_SGD_InferType<2, 1, 4>) +.set_attr("FInferType", MP_InferType<2, 1, 4>) .set_attr("FMutateInputs", [](const nnvm::NodeAttrs& attrs) { return std::vector{2, 3}; @@ -705,6 +707,86 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a .add_arguments(AdamParam::__FIELDS__()); +NNVM_REGISTER_OP(nag_update) +MXNET_ADD_SPARSE_OP_ALIAS(nag_update) +.describe(R"code(Update function for Nesterov Accelerated Gradient( NAG) optimizer. +NAG update consists of the following steps, + +state = momentum * state + grad + wd * weight +weight = weight - (lr * (grad + momentum * state)) +)code" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<2, 1>) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FCompute", NAGUpdate) +.add_argument("weight", "NDArray-or-Symbol", "Weight") +.add_argument("grad", "NDArray-or-Symbol", "Gradient") +.add_arguments(NAGParam::__FIELDS__()); + + +NNVM_REGISTER_OP(nag_mom_update) +MXNET_ADD_SPARSE_OP_ALIAS(nag_mom_update) +.describe(R"code(Update function for Nesterov Accelerated Gradient( NAG) optimizer. +)code" ADD_FILELINE) +.set_num_inputs(3) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<3, 1>) +.set_attr("FInferType", ElemwiseType<3, 1>) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{2}; + }) +.set_attr("FCompute", NAGMomUpdate) +.add_argument("weight", "NDArray-or-Symbol", "Weight") +.add_argument("grad", "NDArray-or-Symbol", "Gradient") +.add_argument("mom", "NDArray-or-Symbol", "Momentum") +.add_arguments(NAGMomParam::__FIELDS__()); + + +NNVM_REGISTER_OP(mp_nag_update) +MXNET_ADD_SPARSE_OP_ALIAS(mp_nag_update) +.describe(R"code(Multi-precision NAG update. +)code" ADD_FILELINE) +.set_num_inputs(3) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<3, 1>) +.set_attr("FInferType", MP_InferType<2, 1, 3>) +.set_attr("FCompute", MP_NAGUpdate) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{2}; + }) +.add_argument("weight", "NDArray-or-Symbol", "Weight") +.add_argument("grad", "NDArray-or-Symbol", "gradient") +.add_argument("weight32", "NDArray-or-Symbol", "Weight32") +.add_arguments(NAGParam::__FIELDS__()); + + +NNVM_REGISTER_OP(mp_nag_mom_update) +MXNET_ADD_SPARSE_OP_ALIAS(mp_nag_mom_update) +.describe(R"code(Multi-precision NAG update. +)code" ADD_FILELINE) +.set_num_inputs(4) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<4, 1>) +.set_attr("FInferType", MP_InferType<2, 1, 4>) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{2, 3}; + }) +.set_attr("FCompute", MP_NAGMomUpdate) +.add_argument("weight", "NDArray-or-Symbol", "Weight") +.add_argument("grad", "NDArray-or-Symbol", "Gradient") +.add_argument("mom", "NDArray-or-Symbol", "Momentum") +.add_argument("weight32", "NDArray-or-Symbol", "Weight32") +.add_arguments(NAGMomParam::__FIELDS__()); + + NNVM_REGISTER_OP(rmsprop_update) .describe(R"code(Update function for `RMSProp` optimizer. diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu index c42cf1831c43..5d361b7e528d 100644 --- a/src/operator/optimizer_op.cu +++ b/src/operator/optimizer_op.cu @@ -251,6 +251,18 @@ NNVM_REGISTER_OP(multi_mp_sgd_update) NNVM_REGISTER_OP(multi_mp_sgd_mom_update) .set_attr("FCompute", MultiSGDMomUpdate); +NNVM_REGISTER_OP(nag_update) +.set_attr("FCompute", NAGUpdate); + +NNVM_REGISTER_OP(nag_mom_update) +.set_attr("FCompute", NAGMomUpdate); + +NNVM_REGISTER_OP(mp_nag_update) +.set_attr("FCompute", MP_NAGUpdate); + +NNVM_REGISTER_OP(mp_nag_mom_update) +.set_attr("FCompute", MP_NAGMomUpdate); + NNVM_REGISTER_OP(ftml_update) .set_attr("FCompute", FTMLUpdate); diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index d5aabcb4b1e5..e151cfde2306 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -346,7 +346,7 @@ def create_state(self, index, weight): if self.momentum != 0.0: momentum = mx.nd.zeros(weight.shape, weight.context, dtype=np.float32) weight_master_copy = array(weight, ctx=weight.context, dtype=np.float32) - return (weight_master_copy, momentum) + return (momentum, weight_master_copy) else: if self.momentum != 0.0: momentum = mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype) @@ -394,8 +394,8 @@ def update(self, index, weight, grad, state): grad32 = grad32 * self.rescale_grad if self.clip_gradient is not None: grad32 = mx.nd.clip(grad32, -self.clip_gradient, self.clip_gradient) - mom = state[1] - weight32 = state[0] + mom = state[0] + weight32 = state[1] if self.momentum == 0.0: weight32[:] += -lr * (grad32 + wd * weight32) else: @@ -417,23 +417,15 @@ def test_nag(): rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}] wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}] mp_options = [{}, {'multi_precision': False}, {'multi_precision': True}] + for dtype in [np.float16, np.float32, np.float64]: - for mom_option in mom_options: - for cg_option in cg_options: - for rg_option in rg_options: - for wd_option in wd_options: - for mp_option in mp_options: - kwarg = {} - kwarg.update(mom_option) - kwarg.update(cg_option) - kwarg.update(rg_option) - kwarg.update(wd_option) - kwarg.update(mp_option) - if (dtype == np.float16 and - ('multi_precision' not in kwarg or + for params in itertools.product(mom_options, cg_options, rg_options, + wd_options, mp_options): + kwarg = {k: v for param in params for k, v in param.items()} + if (dtype == np.float16 and ('multi_precision' not in kwarg or not kwarg['multi_precision'])): - continue - compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype) + continue + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype) #SGLD class PySGLD(mx.optimizer.Optimizer):