From 08c9271eb51f1d98abd81b963c94fc3a1c918234 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Wed, 23 Jan 2019 22:32:31 -0800 Subject: [PATCH] Aggregate SGD (#13346) * Aggregate SGD * Make OpWrapperGenerator understand Tuple * Trigger * Add NNVM Tuple to cpp-package op.h * Trigger * Fix pylint aggregate SGD * Update info about new ENV vars and modifying 2 tests that require update_on_kvstore to be true * Fix * Aggregate SGD support for Gluon trainer * Added text to doc about aggregate update in SGD optimizer * Docs changes from review --- cpp-package/scripts/OpWrapperGenerator.py | 4 +- docs/faq/env_var.md | 9 + python/mxnet/gluon/trainer.py | 15 +- python/mxnet/model.py | 10 +- python/mxnet/optimizer/optimizer.py | 231 +++++++++++---- src/operator/optimizer_op-inl.h | 295 ++++++++++++++++++++ src/operator/optimizer_op.cc | 193 ++++++++++++- src/operator/optimizer_op.cu | 9 + tests/python/unittest/test_gluon_trainer.py | 8 +- tests/python/unittest/test_module.py | 3 + 10 files changed, 711 insertions(+), 66 deletions(-) diff --git a/cpp-package/scripts/OpWrapperGenerator.py b/cpp-package/scripts/OpWrapperGenerator.py index ca430ec99e6e..65ba247c25c8 100644 --- a/cpp-package/scripts/OpWrapperGenerator.py +++ b/cpp-package/scripts/OpWrapperGenerator.py @@ -97,7 +97,8 @@ class Arg: 'double':'double',\ 'double or None':'dmlc::optional',\ 'Shape or None':'dmlc::optional',\ - 'string':'const std::string&'} + 'string':'const std::string&',\ + 'tuple of ':'nnvm::Tuple'} name = '' type = '' description = '' @@ -407,6 +408,7 @@ def ParseAllOps(): "#include \"mxnet-cpp/op_util.h\"\n" "#include \"mxnet-cpp/operator.h\"\n" "#include \"dmlc/optional.h\"\n" + "#include \"nnvm/tuple.h\"\n" "\n" "namespace mxnet {\n" "namespace cpp {\n" diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md index 98057d0d76d6..99ebae21d61f 100644 --- a/docs/faq/env_var.md +++ b/docs/faq/env_var.md @@ -145,6 +145,10 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 - If true, MXNet tries to use GPU peer-to-peer communication, if available on your device, when kvstore's type is `device`. +* MXNET_UPDATE_ON_KVSTORE + - Values: 0(false) or 1(true) ```(default=1)``` + - If true, weight updates are performed during the communication step, if possible. + ## Memonger * MXNET_BACKWARD_DO_MIRROR @@ -218,6 +222,11 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments ca - When the array size is bigger than or equal to this threshold, NDArray::Copy(from, to) is implemented by OpenMP with the Recommended OMP Thread Count. - When the array size is less than this threshold, NDArray::Copy(from , to)) is implemented by memcpy in single thread. +* MXNET_OPTIMIZER_AGGREGATION_SIZE + - Values: Int ```(default=4)``` + - Maximum value is 60. + - This variable controls how many weights will be updated in a single call to optimizer (for optimizers that support aggregation, currently limited to SGD). + Settings for Minimum Memory Usage --------------------------------- - Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1``` diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index f6c0a31b52e2..8060f38ac2aa 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -60,7 +60,8 @@ class Trainer(object): See mxnet.KVStore.set_gradient_compression method for more details on gradient compression. update_on_kvstore : bool, default None Whether to perform parameter updates on kvstore. 
If None, then trainer will choose the more - suitable option depending on the type of kvstore. + suitable option depending on the type of kvstore. If the `update_on_kvstore` argument is + provided, environment variable `MXNET_UPDATE_ON_KVSTORE` will be ignored. Properties ---------- @@ -393,6 +394,8 @@ def update(self, batch_size, ignore_stale_grad=False): self._update(ignore_stale_grad) def _update(self, ignore_stale_grad=False): + updates = [[] for _ in self._updaters] + for i, param in enumerate(self._params): if param.grad_req == 'null': continue @@ -416,11 +419,17 @@ def _update(self, ignore_stale_grad=False): self._kvstore.pull(i, param.list_data(), priority=-i) continue - for upd, arr, grad in zip(self._updaters, param.list_data(), param.list_grad()): + for upd, arr, grad in zip(updates, param.list_data(), param.list_grad()): if not ignore_stale_grad or arr._fresh_grad: - upd(i, grad, arr) + upd.append((i, grad, arr)) arr._fresh_grad = False + if not (self._kvstore and self._update_on_kvstore): + for updater, upd in zip(self._updaters, updates): + if upd: + i, w, g = zip(*upd) + updater(i, w, g) + def save_states(self, fname): """Saves trainer states (e.g. optimizer, momentum) to a file. diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 38fe739154d5..c08077cc65f4 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -92,14 +92,14 @@ def _create_kvstore(kvstore, num_device, arg_params): arg_params : dict of str to `NDArray`. Model parameter, dict of name to `NDArray` of net's weights. """ - update_on_kvstore = True + update_on_kvstore = bool(int(os.getenv('MXNET_UPDATE_ON_KVSTORE', "1"))) if kvstore is None: kv = None elif isinstance(kvstore, kvs.KVStore): kv = kvstore elif isinstance(kvstore, str): # create kvstore using the string type - if num_device is 1 and 'dist' not in kvstore: + if num_device == 1 and 'dist' not in kvstore: # no need to use kv for single device and single machine kv = None else: @@ -162,6 +162,7 @@ def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names): def _update_params(param_arrays, grad_arrays, updater, num_device, kvstore=None, param_names=None): """Perform update of param_arrays from grad_arrays not on kvstore.""" + updates = [[] for _ in range(num_device)] for i, pair in enumerate(zip(param_arrays, grad_arrays)): arg_list, grad_list = pair if grad_list[0] is None: @@ -178,7 +179,10 @@ def _update_params(param_arrays, grad_arrays, updater, num_device, # state for the same index but on diff devs, TODO(mli) # use a better solution later w, g = p - updater(index*num_device+k, g, w) + updates[k].append((index*num_device+k, g, w)) + for dev_updates in updates: + i, w, g = zip(*dev_updates) + updater(i, w, g) def _multiple_callbacks(callbacks, *args, **kwargs): diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py index 6ffbbcffc384..cb52ac54fdab 100644 --- a/python/mxnet/optimizer/optimizer.py +++ b/python/mxnet/optimizer/optimizer.py @@ -22,12 +22,15 @@ import math import pickle import warnings +import os import numpy from ..base import py_str from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply) from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update, mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update, - signsgd_update, signum_update) + signsgd_update, signum_update, + multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update, + multi_mp_sgd_mom_update) from ..ndarray 
import sparse from ..random import normal @@ -37,6 +40,8 @@ 'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register' ] +def _flatten_list(nested_list): + return [item for sublist in nested_list for item in sublist] class Optimizer(object): """The base class inherited by all optimizers. @@ -105,6 +110,7 @@ def __init__(self, rescale_grad=1., param_idx2name=None, wd=0., self._index_update_count = {} self.clip_gradient = clip_gradient self.multi_precision = multi_precision + self.aggregate_num = 0 if param_idx2name is None: param_idx2name = {} @@ -380,13 +386,44 @@ def _update_count(self, index): Parameters ---------- - index : int + index : int or list of int The index to be updated. """ - if index not in self._index_update_count: - self._index_update_count[index] = self.begin_num_update - self._index_update_count[index] += 1 - self.num_update = max(self._index_update_count[index], self.num_update) + if not isinstance(index, (list, tuple)): + index = [index] + for idx in index: + if idx not in self._index_update_count: + self._index_update_count[idx] = self.begin_num_update + self._index_update_count[idx] += 1 + self.num_update = max(self._index_update_count[idx], self.num_update) + + def _get_lrs(self, indices): + """Gets the learning rates given the indices of the weights. + + Parameters + ---------- + indices : list of int + Indices corresponding to weights. + + Returns + ------- + lrs : list of float + Learning rates for those indices. + """ + if self.lr_scheduler is not None: + lr = self.lr_scheduler(self.num_update) + else: + lr = self.lr + + lrs = [lr for _ in indices] + for i, index in enumerate(indices): + if index in self.param_dict: + lrs[i] *= self.param_dict[index].lr_mult + elif index in self.lr_mult: + lrs[i] *= self.lr_mult[index] + elif index in self.idx2name: + lrs[i] *= self.lr_mult.get(self.idx2name[index], 1.0) + return lrs def _get_lr(self, index): """Gets the learning rate given the index of the weight. @@ -401,18 +438,31 @@ def _get_lr(self, index): lr : float Learning rate for this index. """ - if self.lr_scheduler is not None: - lr = self.lr_scheduler(self.num_update) - else: - lr = self.lr + return self._get_lrs([index])[0] - if index in self.param_dict: - lr *= self.param_dict[index].lr_mult - elif index in self.lr_mult: - lr *= self.lr_mult[index] - elif index in self.idx2name: - lr *= self.lr_mult.get(self.idx2name[index], 1.0) - return lr + def _get_wds(self, indices): + """Gets weight decays for indices. + Returns 0 for non-weights if the name of weights are provided for `__init__`. + + Parameters + ---------- + indices : list of int + Indices of weights. + + Returns + ------- + wds : list of float + Weight decays for those indices. + """ + wds = [self.wd for _ in indices] + for i, index in enumerate(indices): + if index in self.param_dict: + wds[i] *= self.param_dict[index].wd_mult + elif index in self.wd_mult: + wds[i] *= self.wd_mult[index] + elif index in self.idx2name: + wds[i] *= self.wd_mult.get(self.idx2name[index], 1.0) + return wds def _get_wd(self, index): """Gets weight decay for index. @@ -421,21 +471,14 @@ def _get_wd(self, index): Parameters ---------- index : int - The index for weight. + The index of weight. Returns ------- wd : float Weight decay for this index. 
""" - wd = self.wd - if index in self.param_dict: - wd *= self.param_dict[index].wd_mult - elif index in self.wd_mult: - wd *= self.wd_mult[index] - elif index in self.idx2name: - wd *= self.wd_mult.get(self.idx2name[index], 1.0) - return wd + return self._get_wds([index])[0] def __getstate__(self): ret = self.__dict__.copy() @@ -471,6 +514,13 @@ class SGD(Optimizer): provides slightly different semantics than the original update, and may lead to different empirical results. + In the case when ``update_on_kvstore`` is set to False (either globally via + MXNET_UPDATE_ON_KVSTORE=0 environment variable or as a parameter in + :class:`~mxnet.gluon.Trainer`) SGD optimizer can perform aggregated update + of parameters, which may lead to improved performance. The aggregation size + is controlled by MXNET_OPTIMIZER_AGGREGATION_SIZE environment variable and + defaults to 4. + Otherwise, **standard updates** are applied by:: rescaled_grad = lr * (rescale_grad * clip(grad, clip_gradient) + wd * weight) @@ -502,6 +552,7 @@ def __init__(self, momentum=0.0, lazy_update=True, **kwargs): super(SGD, self).__init__(**kwargs) self.momentum = momentum self.lazy_update = lazy_update + self.aggregate_num = int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', "4")) def create_state_multi_precision(self, index, weight): weight_master_copy = None @@ -522,12 +573,22 @@ def create_state(self, index, weight): momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype) return momentum - def _update_impl(self, index, weight, grad, state, multi_precision=False): - assert(isinstance(weight, NDArray)) - assert(isinstance(grad, NDArray)) - self._update_count(index) - lr = self._get_lr(index) - wd = self._get_wd(index) + def _update_impl(self, indices, weights, grads, states, multi_precision=False): + aggregate = True + if not isinstance(indices, (tuple, list)): + indices = [indices] + weights = [weights] + grads = [grads] + states = [states] + for weight, grad in zip(weights, grads): + assert(isinstance(weight, NDArray)) + assert(isinstance(grad, NDArray)) + aggregate = (aggregate and + weight.stype == 'default' and + grad.stype == 'default') + self._update_count(indices) + lrs = self._get_lrs(indices) + wds = self._get_wds(indices) kwargs = {'rescale_grad': self.rescale_grad} if self.momentum > 0: @@ -535,26 +596,49 @@ def _update_impl(self, index, weight, grad, state, multi_precision=False): if self.clip_gradient: kwargs['clip_gradient'] = self.clip_gradient - if not multi_precision: - if state is not None: - sgd_mom_update(weight, grad, state, out=weight, - lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs) + if aggregate: + if not multi_precision: + if self.momentum > 0: + multi_sgd_mom_update(*_flatten_list(zip(weights, grads, states)), out=weights, + num_weights=len(weights), lrs=lrs, wds=wds, **kwargs) + else: + multi_sgd_update(*_flatten_list(zip(weights, grads)), out=weights, + num_weights=len(weights), lrs=lrs, wds=wds, **kwargs) else: - sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update, - lr=lr, wd=wd, **kwargs) + if self.momentum > 0: + multi_mp_sgd_mom_update(*_flatten_list(zip(weights, grads, *zip(*states))), + out=weights, num_weights=len(weights), + lrs=lrs, wds=wds, **kwargs) + else: + multi_mp_sgd_update(*_flatten_list(zip(weights, grads, + list(zip(*states))[1])), + out=weights, num_weights=len(weights), + lrs=lrs, wds=wds, **kwargs) else: - if state[0] is not None: - mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight, - lr=lr, wd=wd, **kwargs) - else: - 
mp_sgd_update(weight, grad, state[1], out=weight, - lr=lr, wd=wd, **kwargs) + for weight, grad, state, lr, wd in zip(weights, grads, states, lrs, wds): + if not multi_precision: + if state is not None: + sgd_mom_update(weight, grad, state, out=weight, + lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs) + else: + sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update, + lr=lr, wd=wd, **kwargs) + else: + if state[0] is not None: + mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight, + lr=lr, wd=wd, **kwargs) + else: + mp_sgd_update(weight, grad, state[1], out=weight, + lr=lr, wd=wd, **kwargs) def update(self, index, weight, grad, state): self._update_impl(index, weight, grad, state, multi_precision=False) def update_multi_precision(self, index, weight, grad, state): - use_multi_precision = self.multi_precision and weight.dtype == numpy.float16 + if not isinstance(index, (tuple, list)): + use_multi_precision = self.multi_precision and weight.dtype == numpy.float16 + else: + use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16 self._update_impl(index, weight, grad, state, multi_precision=use_multi_precision) @@ -1525,20 +1609,55 @@ def __init__(self, optimizer): self.optimizer = optimizer self.states = {} self.states_synced = {} + self.aggregate_updates = optimizer.aggregate_num > 0 def __call__(self, index, grad, weight): """Updates weight given gradient and index.""" - # convert ctypes.char_p.value back to python str if needed - if isinstance(index, bytes): - index = py_str(index) - if index not in self.states: - self.states[index] = self.optimizer.create_state_multi_precision(index, weight) - self.states_synced[index] = True - elif not self.states_synced[index]: - self.states[index] = \ - self.sync_state_context(self.states[index], weight.context) - self.states_synced[index] = True - self.optimizer.update_multi_precision(index, weight, grad, self.states[index]) + if not isinstance(index, (list, tuple)): + indices = [index] + grads = [grad] + weights = [weight] + else: + indices = index + grads = grad + weights = weight + for i, idx in enumerate(indices): + # convert ctypes.char_p.value back to python str if needed + if isinstance(idx, bytes): + indices[i] = py_str(idx) + idx = indices[i] + if idx not in self.states: + self.states[idx] = self.optimizer.create_state_multi_precision(idx, weights[i]) + self.states_synced[idx] = True + elif not self.states_synced[idx]: + self.states[idx] = \ + self.sync_state_context(self.states[idx], weights[i].context) + self.states_synced[idx] = True + if self.aggregate_updates: + # segregate values based on type + type_map = {} + for i, w, g in zip(indices, weights, grads): + if w.dtype in type_map: + type_map[w.dtype].append((i, w, g)) + else: + type_map[w.dtype] = [(i, w, g)] + for idx in type_map: + current_index = 0 + indices, weights, grads = zip(*type_map[idx]) + while current_index < len(indices): + states = [] + step = min(self.optimizer.aggregate_num, len(indices) - current_index) + for j in range(step): + states.append(self.states[indices[current_index + j]]) + self.optimizer.update_multi_precision( + indices[current_index:current_index + self.optimizer.aggregate_num], + weights[current_index:current_index + self.optimizer.aggregate_num], + grads[current_index:current_index + self.optimizer.aggregate_num], + states) + current_index += self.optimizer.aggregate_num + else: + for i, w, g in zip(indices, weights, grads): + self.optimizer.update_multi_precision(i, w, g, self.states[i]) def 
sync_state_context(self, state, context): """sync state context.""" diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 9251b8614806..223a1aa6c37d 100644 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -82,6 +82,301 @@ struct SGDParam : public dmlc::Parameter { } }; +struct MultiSGDParam : public dmlc::Parameter { + nnvm::Tuple lrs; + nnvm::Tuple wds; + float rescale_grad; + float clip_gradient; + int num_weights; + DMLC_DECLARE_PARAMETER(MultiSGDParam) { + DMLC_DECLARE_FIELD(lrs) + .describe("Learning rates."); + DMLC_DECLARE_FIELD(wds) + .describe("Weight decay augments the objective function with a " + "regularization term that penalizes large weights. " + "The penalty scales with the square of the magnitude of each weight."); + DMLC_DECLARE_FIELD(rescale_grad) + .set_default(1.0f) + .describe("Rescale gradient to grad = rescale_grad*grad."); + DMLC_DECLARE_FIELD(clip_gradient) + .set_default(-1.0f) + .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. " + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + DMLC_DECLARE_FIELD(num_weights) + .set_default(1) + .describe("Number of updated weights."); + } +}; + +struct MultiSGDMomParam : public dmlc::Parameter { + nnvm::Tuple lrs; + nnvm::Tuple wds; + float momentum; + float rescale_grad; + float clip_gradient; + int num_weights; + DMLC_DECLARE_PARAMETER(MultiSGDMomParam) { + DMLC_DECLARE_FIELD(lrs) + .describe("Learning rates."); + DMLC_DECLARE_FIELD(wds) + .describe("Weight decay augments the objective function with a " + "regularization term that penalizes large weights. " + "The penalty scales with the square of the magnitude of each weight."); + DMLC_DECLARE_FIELD(momentum) + .set_default(0.0f) + .describe("The decay rate of momentum estimates at each epoch."); + DMLC_DECLARE_FIELD(rescale_grad) + .set_default(1.0f) + .describe("Rescale gradient to grad = rescale_grad*grad."); + DMLC_DECLARE_FIELD(clip_gradient) + .set_default(-1.0f) + .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. " + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + DMLC_DECLARE_FIELD(num_weights) + .set_default(1) + .describe("Number of updated weights."); + } +}; + +template +inline bool MultiSGDShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const ParamType& param = dmlc::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), input_stride * param.num_weights); + CHECK_EQ(out_attrs->size(), param.num_weights); + + bool all_inferred = true; + auto& input_shapes = *in_attrs; + auto& output_shapes = *out_attrs; + // Learning rates + CHECK_EQ(param.lrs.ndim(), param.num_weights) + << "Number of learning rates is inconsistent with num_weights " + << "parameter passed. Expected number of learning rates: " + << param.num_weights << ", and got " << param.lrs.ndim(); + // Weight decays + CHECK_EQ(param.wds.ndim(), param.num_weights) + << "Number of weight decays is inconsistent with num_weights " + << "parameter passed. 
Expected number of weight decays: " + << param.num_weights << ", and got " << param.wds.ndim(); + // Weights and gradients + for (int i = 0; i < param.num_weights; ++i) { + std::vector input_vec; + std::vector output_vec({output_shapes[i]}); + for (int j = 0; j < input_stride; ++j) { + input_vec.push_back(input_shapes[i * input_stride + j]); + } + all_inferred = all_inferred && ElemwiseShape(attrs, &input_vec, &output_vec); + } + return all_inferred; +} + +template +inline bool MP_MultiSGD_InferType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const ParamType& param = dmlc::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), input_stride * param.num_weights); + CHECK_EQ(out_attrs->size(), param.num_weights); + + bool all_inferred = true; + auto& input_types = *in_attrs; + auto& output_types = *out_attrs; + // Weights and gradients + for (int i = 0; i < param.num_weights; ++i) { + std::vector input_vec; + std::vector output_vec({output_types[i]}); + for (int j = 0; j < input_stride - num_fp32_inputs; ++j) { + input_vec.push_back(input_types[i * input_stride + j]); + } + all_inferred = all_inferred && + ElemwiseType(attrs, &input_vec, &output_vec); + } + // master copies of weights + for (int i = 0; i < param.num_weights; ++i) { + for (int j = 0; j < num_fp32_inputs; ++j) { + TYPE_ASSIGN_CHECK(input_types, input_stride * i + input_stride - 1 - j, mshadow::kFloat32); + } + } + return all_inferred; +} + +template +struct MultiSGDKernelParam { + static const int N = 60; + int count; + size_t max_size; + size_t sizes[N]; + DType * weights[N]; + DType * grads[N]; + MPDType * mom[N]; + MPDType * weights32[N]; + DType * out_data[N]; + MPDType lrs[N]; + MPDType wds[N]; + MPDType clip_gradient; + MPDType rescale_grad; + MPDType momentum; +}; + +template +struct MultiSGDKernel { + template + MSHADOW_XINLINE static void Map(int i, const MultiSGDKernelParam& param, + const OpReqType req) { + for (int index = 0; index < param.count; ++index) { + if ((size_t)i < param.sizes[index]) { + MPDType w = has_mixed_precision ? param.weights32[index][i] : + MPDType(param.weights[index][i]); + MPDType mom = has_momentum ? 
param.mom[index][i] : MPDType(0); + if (param.clip_gradient >= 0.0f) { + mom = param.momentum*mom + - param.lrs[index]*param.wds[index]*w + - param.lrs[index] + *mshadow_op::clip::Map(param.rescale_grad * + static_cast(param.grads[index][i]), + param.clip_gradient); + } else { + mom = param.momentum*mom + - param.lrs[index]*param.wds[index]*w + - param.lrs[index]*param.rescale_grad*static_cast(param.grads[index][i]); + } + if (has_momentum) { + param.mom[index][i] = mom; + } + w = w + mom; + if (has_mixed_precision) { + param.weights32[index][i] = w; + } + KERNEL_ASSIGN(param.out_data[index][i], req, w); + } + } + } +}; + +template +MultiSGDKernelParam FillMultiSGDKernelParam(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &outputs) { + using namespace mxnet_op; + const ParamType& p = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MultiSGDKernelParam param; + param.clip_gradient = p.clip_gradient; + param.rescale_grad = p.rescale_grad; + param.momentum = 0; + param.count = p.num_weights; + param.max_size = 0; + for (int i = 0; i < param.count; ++i) { + param.sizes[i] = inputs[i * input_stride].shape_.Size(); + if (param.max_size < param.sizes[i]) { + param.max_size = param.sizes[i]; + } + param.weights[i] = inputs[i * input_stride].FlatTo2D(s).dptr_; + param.grads[i] = inputs[i * input_stride + 1].FlatTo2D(s).dptr_; + // if mixed precision, then the last input in a set + // is 32-bit master copy of the weights + if (!std::is_same::value) { + param.weights32[i] = inputs[i * input_stride + input_stride - 1] + .FlatTo2D(s).dptr_; + } + param.out_data[i] = outputs[i].FlatTo2D(s).dptr_; + param.lrs[i] = p.lrs[i]; + param.wds[i] = p.wds[i]; + } + + return param; +} + + +template +MultiSGDKernelParam FillMultiSGDMomKernelParam(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &outputs) { + using namespace mxnet_op; + const MultiSGDMomParam& p = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MultiSGDKernelParam param = + FillMultiSGDKernelParam(attrs, ctx, inputs, outputs); + param.momentum = p.momentum; + for (int i = 0; i < param.count; ++i) { + param.mom[i] = inputs[i * input_stride + 2].FlatTo2D(s).dptr_; + } + + return param; +} + +template +class type_identity { + public: + using type = T; +}; + +template +class single_precision { + public: + using type = float; +}; + +template class MPTypeChooser, int input_stride> +inline void MultiSGDUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + using MPDType = typename MPTypeChooser::type; + MultiSGDKernelParam param = + FillMultiSGDKernelParam(attrs, ctx, inputs, outputs); + Kernel::value>, + xpu>::Launch(s, param.max_size, param, req[0]); + }); +} + +template class MPTypeChooser, int input_stride> +inline void MultiSGDMomUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + using MPDType = typename MPTypeChooser::type; + MultiSGDKernelParam param = + FillMultiSGDMomKernelParam(attrs, ctx, inputs, outputs); + Kernel::value>, + xpu>::Launch(s, param.max_size, param, req[0]); + }); +} struct 
SGDKernel { template diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index a52a6f32907c..982995ad2f95 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -31,6 +31,8 @@ namespace op { DMLC_REGISTER_PARAMETER(SGDParam); DMLC_REGISTER_PARAMETER(SGDMomParam); +DMLC_REGISTER_PARAMETER(MultiSGDParam); +DMLC_REGISTER_PARAMETER(MultiSGDMomParam); DMLC_REGISTER_PARAMETER(FTMLParam); DMLC_REGISTER_PARAMETER(AdamParam); DMLC_REGISTER_PARAMETER(RMSPropParam); @@ -52,7 +54,7 @@ It updates the weights using:: weight = weight - learning_rate * sign(gradient) -.. note:: +.. note:: - sparse ndarray not supported for this optimizer yet. )code" ADD_FILELINE) .set_num_inputs(2) @@ -81,7 +83,7 @@ It updates the weights using:: Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch. -.. note:: +.. note:: - sparse ndarray not supported for this optimizer yet. )code" ADD_FILELINE) .set_num_inputs(3) @@ -313,6 +315,193 @@ inline bool SGDStorageType(const nnvm::NodeAttrs& attrs, return dispatched; } +NNVM_REGISTER_OP(multi_sgd_update) +.describe(R"code(Update function for Stochastic Gradient Descent (SDG) optimizer. + +It updates the weights using:: + + weight = weight - learning_rate * (gradient + wd * weight) + +)code" ADD_FILELINE) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights * 2); + }) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights); + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", MultiSGDShape) +.set_attr("FInferType", ElemwiseType<-1, -1>) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + uint32_t num_args = dmlc::get(attrs.parsed).num_weights; + std::vector ret; + for (uint32_t i = 0; i < num_args; ++i) { + ret.push_back(std::string("weight_") + std::to_string(i)); + ret.push_back(std::string("grad_") + std::to_string(i)); + } + return ret; + }) +.set_attr("FCompute", MultiSGDUpdate) +.add_argument("data", "NDArray-or-Symbol[]", "Weights") +.add_arguments(MultiSGDParam::__FIELDS__()); + +NNVM_REGISTER_OP(multi_sgd_mom_update) +.describe(R"code(Momentum update function for Stochastic Gradient Descent (SGD) optimizer. + +Momentum update has better convergence rates on neural networks. Mathematically it looks +like below: + +.. math:: + + v_1 = \alpha * \nabla J(W_0)\\ + v_t = \gamma v_{t-1} - \alpha * \nabla J(W_{t-1})\\ + W_t = W_{t-1} + v_t + +It updates the weights using:: + + v = momentum * v - learning_rate * gradient + weight += v + +Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch. 
+ +)code" ADD_FILELINE) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDMomParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights * 3); + }) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDMomParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights); + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", MultiSGDShape) +.set_attr("FInferType", ElemwiseType<-1, -1>) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + uint32_t num_args = dmlc::get(attrs.parsed).num_weights; + std::vector ret; + for (uint32_t i = 0; i < num_args; ++i) { + ret.push_back(std::string("weight_") + std::to_string(i)); + ret.push_back(std::string("grad_") + std::to_string(i)); + ret.push_back(std::string("mom_") + std::to_string(i)); + } + return ret; + }) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + std::vector ret; + const MultiSGDMomParam& param = dmlc::get(attrs.parsed); + for (int i = 0; i < param.num_weights; ++i) { + ret.push_back(i * 3 + 2); + } + return ret; + }) +.set_attr("FCompute", MultiSGDMomUpdate) +.add_argument("data", "NDArray-or-Symbol[]", "Weights, gradients and momentum") +.add_arguments(MultiSGDMomParam::__FIELDS__()); + +NNVM_REGISTER_OP(multi_mp_sgd_update) +.describe(R"code(Update function for multi-precision Stochastic Gradient Descent (SDG) optimizer. + +It updates the weights using:: + + weight = weight - learning_rate * (gradient + wd * weight) + +)code" ADD_FILELINE) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights * 3); + }) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights); + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", MultiSGDShape) +.set_attr("FInferType", MP_MultiSGD_InferType) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + uint32_t num_args = dmlc::get(attrs.parsed).num_weights; + std::vector ret; + for (uint32_t i = 0; i < num_args; ++i) { + ret.push_back(std::string("weight_") + std::to_string(i)); + ret.push_back(std::string("grad_") + std::to_string(i)); + ret.push_back(std::string("weight32_") + std::to_string(i)); + } + return ret; + }) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + std::vector ret; + const MultiSGDParam& param = dmlc::get(attrs.parsed); + for (int i = 0; i < param.num_weights; ++i) { + ret.push_back(i * 3 + 2); + } + return ret; + }) +.set_attr("FCompute", MultiSGDUpdate) +.add_argument("data", "NDArray-or-Symbol[]", "Weights") +.add_arguments(MultiSGDParam::__FIELDS__()); + +NNVM_REGISTER_OP(multi_mp_sgd_mom_update) +.describe(R"code(Momentum update function for multi-precision Stochastic Gradient Descent (SGD) optimizer. + +Momentum update has better convergence rates on neural networks. Mathematically it looks +like below: + +.. math:: + + v_1 = \alpha * \nabla J(W_0)\\ + v_t = \gamma v_{t-1} - \alpha * \nabla J(W_{t-1})\\ + W_t = W_{t-1} + v_t + +It updates the weights using:: + + v = momentum * v - learning_rate * gradient + weight += v + +Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch. 
+ +)code" ADD_FILELINE) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDMomParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights * 4); + }) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDMomParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights); + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", MultiSGDShape) +.set_attr("FInferType", MP_MultiSGD_InferType) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + uint32_t num_args = dmlc::get(attrs.parsed).num_weights; + std::vector ret; + for (uint32_t i = 0; i < num_args; ++i) { + ret.push_back(std::string("weight_") + std::to_string(i)); + ret.push_back(std::string("grad_") + std::to_string(i)); + ret.push_back(std::string("mom_") + std::to_string(i)); + ret.push_back(std::string("weight32_") + std::to_string(i)); + } + return ret; + }) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + std::vector ret; + const MultiSGDMomParam& param = dmlc::get(attrs.parsed); + for (int i = 0; i < param.num_weights; ++i) { + ret.push_back(i * 4 + 2); + ret.push_back(i * 4 + 3); + } + return ret; + }) +.set_attr("FCompute", MultiSGDMomUpdate) +.add_argument("data", "NDArray-or-Symbol[]", "Weights") +.add_arguments(MultiSGDMomParam::__FIELDS__()); NNVM_REGISTER_OP(sgd_update) MXNET_ADD_SPARSE_OP_ALIAS(sgd_update) diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu index 0fd2ca83fda4..c42cf1831c43 100644 --- a/src/operator/optimizer_op.cu +++ b/src/operator/optimizer_op.cu @@ -242,6 +242,15 @@ NNVM_REGISTER_OP(mp_sgd_update) NNVM_REGISTER_OP(mp_sgd_mom_update) .set_attr("FCompute", MP_SGDMomUpdate); +NNVM_REGISTER_OP(multi_sgd_update) +.set_attr("FCompute", MultiSGDUpdate); +NNVM_REGISTER_OP(multi_sgd_mom_update) +.set_attr("FCompute", MultiSGDMomUpdate); +NNVM_REGISTER_OP(multi_mp_sgd_update) +.set_attr("FCompute", MultiSGDUpdate); +NNVM_REGISTER_OP(multi_mp_sgd_mom_update) +.set_attr("FCompute", MultiSGDMomUpdate); + NNVM_REGISTER_OP(ftml_update) .set_attr("FCompute", FTMLUpdate); diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py index 985c38c31356..9f190a0a88c2 100644 --- a/tests/python/unittest/test_gluon_trainer.py +++ b/tests/python/unittest/test_gluon_trainer.py @@ -17,6 +17,7 @@ import mxnet as mx import unittest +import os import numpy as np from mxnet import gluon from mxnet.gluon import nn @@ -98,6 +99,9 @@ def dict_equ(a, b): @with_seed() def test_trainer_save_load(): + previous_update_on_kvstore = os.getenv('MXNET_UPDATE_ON_KVSTORE', "1") + os.putenv('MXNET_UPDATE_ON_KVSTORE', '1') + x = gluon.Parameter('x', shape=(10,), lr_mult=1.0) x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 0.1}) @@ -112,6 +116,7 @@ def test_trainer_save_load(): x.lr_mult = 2.0 # check if parameter dict is correctly associated with optimizer after load_state assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2 + os.putenv('MXNET_UPDATE_ON_KVSTORE', previous_update_on_kvstore) @with_seed() def test_trainer_sparse_save_load(): @@ -236,10 +241,11 @@ def check_trainer_sparse_kv(kv, stype, grad_stype, update_on_kv, expected): assert isinstance(err, expected) kvs = ['local', 'device'] + global_update_on_kvstore = bool(int(os.getenv('MXNET_UPDATE_ON_KVSTORE', "1"))) for kv in kvs: check_trainer_sparse_kv(kv, 'default', 'default', True, True) check_trainer_sparse_kv(kv, 'default', 'default', False, False) - 
check_trainer_sparse_kv(kv, 'default', 'default', None, True) + check_trainer_sparse_kv(kv, 'default', 'default', None, global_update_on_kvstore) check_trainer_sparse_kv(kv, 'default', 'row_sparse', None, False) check_trainer_sparse_kv(kv, 'default', 'row_sparse', True, True) check_trainer_sparse_kv(kv, 'default', 'row_sparse', False, False) diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index 7347723a39c6..d9d7175f540e 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -174,6 +174,8 @@ def test_module_layout(): @with_seed() def test_save_load(): + previous_update_on_kvstore = os.getenv('MXNET_UPDATE_ON_KVSTORE', "1") + os.putenv('MXNET_UPDATE_ON_KVSTORE', '1') def dict_equ(a, b): assert set(a) == set(b) for k in a: @@ -211,6 +213,7 @@ def dict_equ(a, b): assert mod._symbol.tojson() == mod2._symbol.tojson() dict_equ(mod.get_params()[0], mod2.get_params()[0]) dict_equ(mod._kvstore._updater.states, mod2._updater.states) + os.putenv('MXNET_UPDATE_ON_KVSTORE', previous_update_on_kvstore) @with_seed()
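
A rough usage sketch (not part of the patch itself) of the aggregated SGD path added above, based on the changes to `docs/faq/env_var.md`, `python/mxnet/gluon/trainer.py` and `python/mxnet/optimizer/optimizer.py`; the network, data and label below are placeholders:

```python
import os

# Both variables are read when the Trainer and its optimizer are created, so set them first.
os.environ['MXNET_UPDATE_ON_KVSTORE'] = '0'           # perform updates outside of kvstore
os.environ['MXNET_OPTIMIZER_AGGREGATION_SIZE'] = '4'  # weights per fused optimizer call (max 60)

import mxnet as mx
from mxnet import autograd, gluon

net = gluon.nn.Dense(10)
net.initialize(ctx=mx.cpu())

# Passing update_on_kvstore=False to the Trainer explicitly would have the same
# effect and makes MXNET_UPDATE_ON_KVSTORE be ignored.
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'learning_rate': 0.1, 'momentum': 0.9})

data = mx.nd.random.uniform(shape=(32, 20))
label = mx.nd.random.uniform(shape=(32, 10))
loss_fn = gluon.loss.L2Loss()

with autograd.record():
    loss = loss_fn(net(data), label)
loss.backward()
trainer.step(32)  # SGD now updates up to aggregate_num weights per multi_sgd_mom_update call
```

The fused operators take interleaved per-weight inputs. A minimal direct call, mirroring how `SGD._update_impl` invokes `multi_sgd_mom_update` in this patch (shapes and values below are arbitrary):

```python
ws = [mx.nd.ones((5,)), mx.nd.ones((3, 2))]
gs = [w * 0.1 for w in ws]
ms = [mx.nd.zeros_like(w) for w in ws]

# Inputs are flattened as weight_0, grad_0, mom_0, weight_1, grad_1, mom_1, ...
# and the weights themselves are passed as the outputs to be written in place.
flat = [t for triple in zip(ws, gs, ms) for t in triple]
mx.nd.multi_sgd_mom_update(*flat, out=ws, num_weights=len(ws),
                           lrs=[0.1, 0.1], wds=[1e-4, 1e-4],
                           momentum=0.9, rescale_grad=1.0)
```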