From f987bc4e18aa2087ece8d2d58e73180600f71d28 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 23 Dec 2018 01:40:49 +0000 Subject: [PATCH 01/12] add clarification for param_dict --- python/mxnet/gluon/trainer.py | 177 +++++++++++++------- python/mxnet/model.py | 5 + python/mxnet/optimizer/optimizer.py | 22 ++- tests/python/unittest/test_gluon_trainer.py | 4 +- 4 files changed, 138 insertions(+), 70 deletions(-) diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index c4d49e82c908..e3e52db1ce3e 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -20,6 +20,7 @@ """Parameter optimizer.""" __all__ = ['Trainer'] +import copy from .. import optimizer as opt from ..model import _create_kvstore, _create_sparse_kvstore from .parameter import ParameterDict, Parameter @@ -53,6 +54,12 @@ class Trainer(object): Whether to perform parameter updates on kvstore. If None, then trainer will choose the more suitable option depending on the type of kvstore. + Notes + ----- + `update_on_kvstore=False` is not supported in the following cases: + - dist kvstore with sparse weights or sparse gradients + - dist async kvstore + Properties ---------- learning_rate : float @@ -87,7 +94,13 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device', self._compression_params = compression_params optimizer_params = optimizer_params if optimizer_params else {} self._scale = float(optimizer_params.get('rescale_grad', 1.0)) + # one optimizer / updater per context + # If self._update_on_kvstore is set to `True` in `_init_kvstore()`, then: + # - updaters[:] are never used. + # - optimizer[0] is registered with kvstore. optimizer[1:] are never used. self._contexts = self._check_contexts() + self._optimizers = [] + self._updaters = [] self._init_optimizer(optimizer, optimizer_params) self._kvstore_params = {'kvstore': kvstore, 'update_on_kvstore': update_on_kvstore} self._kv_initialized = False @@ -114,14 +127,18 @@ def _init_optimizer(self, optimizer, optimizer_params): assert not optimizer_params, \ "optimizer_params must be None if optimizer is an instance of " \ "Optimizer instead of str" - self._optimizer = optimizer - self._optimizer.param_dict = param_dict else: - self._optimizer = opt.create(optimizer, param_dict=param_dict, - **optimizer_params) - - self._updaters = [opt.get_updater(self._optimizer) \ - for _ in self._contexts] + optimizer = opt.create(optimizer, **optimizer_params) + optimizer.param_dict = param_dict + self._optimizers = [optimizer] + self._updaters = [opt.get_updater(optimizer)] + # create a deep copy of the optimizer per context + for _ in range(len(self._contexts) - 1): + optim = copy.deepcopy(optimizer) + # param_dict must not be deep copied + optim.param_dict = param_dict + self._optimizers.append(optim) + self._updaters.append(opt.get_updater(optim)) def _init_params(self): """Initialize parameters in the KVStore. @@ -158,57 +175,74 @@ def _reset_kvstore(self): def _init_kvstore(self): """Create kvstore.""" config = self._kvstore_params - # if weight is sparse, the weight must be updated on KVStore. - # training loop contains: - # - row_sparse_pull(sparse_weight) - # - forward() - # - backward() - # - push(sparse_grad), push(dense_grad) - # - pull(dense_weight) + # configure kvstore, update_on_kvstore and self._distributed on three cases: if self._contains_sparse_weight: + # If weight is sparse, kvstore must be present and the weight must be updated on kvstore. 
+ # The training loop is the following: + # - row_sparse_pull(sparse_weight) + # - forward() + # - backward() + # - push_and_update(grad) + # - pull(weight) kvstore, update_on_kvstore = _create_sparse_kvstore(config['kvstore']) - # raise Error if update_on_kvstore is set to False by the user + self._distributed = 'dist' in kvstore.type + # raise err if user provides unsupported configs if config['update_on_kvstore'] is False: - raise RuntimeError("Cannot set update_on_kvstore to False when sparse weights " - "are present.") - # if weight is dense and grad is sparse, the weight better not be updated on KVStore. - # training loop contains: - # - forward() - # - backward() - # - push(grad) - # - pull(grad) - # - update(grad, weight) + raise ValueError("Cannot set update_on_kvstore=False when sparse weights " + "are present.") + elif self._contains_sparse_grad: + # For single node training with dense weight and sparse grad, + # we prefer update_on_kvstore=False because this is usually faster. + # This means we push and pull sparse gradients, and we do not store weight in kvstore. + # The training loop is the following: + # - forward() + # - backward() + # - push(grad) + # - pull(grad) + # - update(grad, weight) + # + # For multi-node training with dense weight and sparse grad, + # only update_on_kvstore=True is supported, due to the fact that + # kv.row_sparse_pull(grad) is not implemented. + # Therefore, we push sparse gradients and pull dense weights. + # The training loop contains: + # - forward() + # - backward() + # - push_and_update(grad) + # - pull(weight) arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params} kvstore, _ = _create_kvstore(config['kvstore'], len(self._contexts), arg_arrays) - update_on_kvstore = False - # normal case + self._distributed = 'dist' in kvstore.type if kvstore else False + update_on_kvstore = self._distributed + # raise err if user provides unsupported configs + if config['update_on_kvstore'] is False and self._distributed: + raise RuntimeError("Cannot set update_on_kvstore=False on dist kvstore " + "when sparse gradients are present.") + else: + # Training with dense weight and dense gradients. 
+ # The only unsupported mode is async with update_on_kvstore=False arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params} kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts), arg_arrays) - if kvstore and 'async' in kvstore.type and config['update_on_kvstore'] is not None\ - and not config['update_on_kvstore']: - raise ValueError("Please set update_on_kvstore to true " - "when training in async mode.") - + self._distributed = 'dist' in kvstore.type if kvstore else False + if self._distributed and 'async' in kvstore.type: + update_on_kvstore = True + # raise err if user provides unsupported configs + if config['update_on_kvstore'] is False: + raise ValueError("Please set update_on_kvstore=True " + "when training in async mode.") if config['update_on_kvstore'] is not None: update_on_kvstore = config['update_on_kvstore'] + # set grad compression and optimizers if kvstore: if self._compression_params: kvstore.set_gradient_compression(self._compression_params) - self._distributed = 'dist' in kvstore.type - if self._distributed: - # kv.pull(row_sparse_grad) is not supported for dist kvstore - # Captures condition for dist_async, dist_device_sync or based on config for - # update_on_kvstore - update_on_kvstore = self._contains_sparse_weight or self._contains_sparse_grad \ - or 'device' in kvstore.type or 'async' in kvstore.type \ - or config['update_on_kvstore'] if update_on_kvstore: # optimizer preferably needs to be set before init for multiprecision - kvstore.set_optimizer(self._optimizer) + kvstore.set_optimizer(self._optimizers[0]) self._kvstore = kvstore self._update_on_kvstore = update_on_kvstore else: @@ -219,11 +253,16 @@ def _init_kvstore(self): @property def learning_rate(self): - if not isinstance(self._optimizer, opt.Optimizer): + if not isinstance(self._optimizers[0], opt.Optimizer): raise UserWarning("Optimizer has to be defined before its learning " "rate can be accessed.") else: - return self._optimizer.learning_rate + lr = self._optimizers[0].learning_rate + for i in range(self._contexts): + if self._optimizers[i].learning_rate != lr: + raise UserWarning("The optimizer on %s has a different learning rate" + " from that on %s. Cannot return learning rate") + return lr def set_learning_rate(self, lr): """Sets a new learning rate of the optimizer. @@ -233,11 +272,14 @@ def set_learning_rate(self, lr): lr : float The new learning rate of the optimizer. """ - if not isinstance(self._optimizer, opt.Optimizer): + if not self._optimizers: raise UserWarning("Optimizer has to be defined before its learning " "rate is mutated.") - else: - self._optimizer.set_learning_rate(lr) + for optim in self._optimizers: + if not isinstance(optim, opt.Optimizer): + raise UserWarning("Optimizer has to be defined before its learning " + "rate is mutated.") + optim.set_learning_rate(lr) def _row_sparse_pull(self, parameter, out, row_id, full_idx=False): """Internal method to invoke pull operations on KVStore. If `full_idx` is set to True, @@ -255,6 +297,17 @@ def _row_sparse_pull(self, parameter, out, row_id, full_idx=False): else: self._kvstore.row_sparse_pull(idx, out=out, row_ids=row_id, priority=-idx) + def _check_and_rescale_grad(self, scale): + for optim in self._optimizers: + if self._update_on_kvstore and self._distributed and self._kv_initialized: + if optim.rescale_grad != scale: + raise UserWarning('Possible change in the `batch_size` from previous ' + '`step` detected. 
Optimizer gradient normalizing ' + 'factor will not change w.r.t new batch_size when ' + 'update_on_kvstore=True and when distributed kvstore ' + 'is used.') + optim.rescale_grad = scale + def step(self, batch_size, ignore_stale_grad=False): """Makes one step of parameter update. Should be called after `autograd.backward()` and outside of `record()` scope. @@ -274,13 +327,7 @@ def step(self, batch_size, ignore_stale_grad=False): been updated by `backward` after last step) and skip update. """ rescale_grad = self._scale / batch_size - if self._update_on_kvstore and self._distributed and \ - self._optimizer.rescale_grad != rescale_grad: - raise UserWarning('Possible change in the `batch_size` from previous `step` detected.' \ - 'Optimizer gradient normalizing factor will not change w.r.t new batch_size when ' \ - 'update_on_kvstore=True and when distributed `kvstore` is used.') - - self._optimizer.rescale_grad = rescale_grad + self._check_and_rescale_grad(rescale_grad) if not self._kv_initialized: self._init_kvstore() @@ -352,7 +399,7 @@ def update(self, batch_size, ignore_stale_grad=False): 'is not supported. Try setting `update_on_kvstore` ' \ 'to False when creating trainer.' - self._optimizer.rescale_grad = self._scale / batch_size + self._check_and_rescale_grad(self._scale / batch_size) self._update(ignore_stale_grad) def _update(self, ignore_stale_grad=False): @@ -387,12 +434,18 @@ def _update(self, ignore_stale_grad=False): def save_states(self, fname): """Saves trainer states (e.g. optimizer, momentum) to a file. + Parameters ---------- fname : str Path to output states file. + + Note + ---- + `optimizer.param_dict`, which contains Parameter information (such as + `lr_mult` and `wd_mult`) will not be saved. """ - assert self._optimizer is not None + assert self._optimizers and self._optimizers[0] is not None if not self._kv_initialized: self._init_kvstore() @@ -414,6 +467,12 @@ def load_states(self, fname): ---------- fname : str Path to input states file. + + Note + ---- + `optimizer.param_dict`, which contains Parameter information (such as + `lr_mult` and `wd_mult`) will not be loaded from the file, but rather set + based on current Trainer's parameters. """ if not self._kv_initialized: self._init_kvstore() @@ -422,13 +481,13 @@ def load_states(self, fname): if self._update_on_kvstore: self._kvstore.load_optimizer_states(fname) - self._optimizer = self._kvstore._updater.optimizer - param_dict = {i: param for i, param in enumerate(self._params)} - self._optimizer.param_dict = param_dict + optimizer = self._kvstore._updater.optimizer + self._init_optimizer(optimizer, None) else: with open(fname, 'rb') as f: states = f.read() + param_dict = {i: param for i, param in enumerate(self._params)} for updater in self._updaters: updater.set_states(states) - updater.optimizer = self._updaters[0].optimizer - self._optimizer = self._updaters[0].optimizer + updater.optimizer.param_dict = param_dict + self._optimizers = [updater.optimizer for updater in self._updaters] diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 2666f8bbcd4f..38fe739154d5 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -62,6 +62,11 @@ def _create_sparse_kvstore(kvstore): ---------- kvstore : KVStore or str The kvstore. + + Returns + ------- + kvstore : KVStore + update_on_kvstore : bool. Always True. 
""" # always update on kvstore update_on_kvstore = True diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py index a085b6fe2ef6..ba16132ab084 100644 --- a/python/mxnet/optimizer/optimizer.py +++ b/python/mxnet/optimizer/optimizer.py @@ -43,33 +43,33 @@ class Optimizer(object): Parameters ---------- - rescale_grad : float, optional + rescale_grad : float, optional, default 1.0 Multiply the gradient with `rescale_grad` before updating. Often choose to be ``1.0/batch_size``. - param_idx2name : dict from int to string, optional + param_idx2name : dict from int to string, optional, default None A dictionary that maps int index to string name. - clip_gradient : float, optional + clip_gradient : float, optional, default None Clip the gradient by projecting onto the box ``[-clip_gradient, clip_gradient]``. - learning_rate : float, optional + learning_rate : float, optional, default 0.01 The initial learning rate. - lr_scheduler : LRScheduler, optional + lr_scheduler : LRScheduler, optional, default None The learning rate scheduler. - wd : float, optional + wd : float, optional, default 0.0 The weight decay (or L2 regularization) coefficient. Modifies objective by adding a penalty for having large weights. - sym: Symbol, optional + sym: Symbol, optional, default None The Symbol this optimizer is applying to. - begin_num_update : int, optional + begin_num_update : int, optional, default 0 The initial number of updates. - multi_precision : bool, optional + multi_precision : bool, optional, default False Flag to control the internal precision of the optimizer.:: False: results in using the same precision as the weights (default), @@ -77,6 +77,10 @@ class Optimizer(object): in 32-bit precision even if actual weights used in the model have lower precision. Turning this on can improve convergence and accuracy when training with float16. + param_dict : dict of int -> gluon.Parameter, default None + Dictionary of parameter index to gluon.Parameter, used to lookup parameter attributes + such as lr_mult, wd_mult, etc. param_dict shall not be deep copied. 
+ Properties ---------- learning_rate : float diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py index 72c01acb2652..62146824d9ca 100644 --- a/tests/python/unittest/test_gluon_trainer.py +++ b/tests/python/unittest/test_gluon_trainer.py @@ -73,11 +73,11 @@ def dict_equ(a, b): trainer.load_states('test_trainer.states') if trainer._update_on_kvstore: dict_equ(trainer._kvstore._updater.states, states) - assert trainer._optimizer == trainer._kvstore._updater.optimizer + assert trainer._optimizers[0] == trainer._kvstore._updater.optimizer else: for updater in trainer._updaters: dict_equ(updater.states, states) - assert trainer._optimizer == trainer._updaters[0].optimizer + assert trainer._optimizers[0] == trainer._updaters[0].optimizer assert_raises(AssertionError, trainer.update, 1) assert_raises(AssertionError, trainer.allreduce_grads) From 2186c493582e20229c8692a7869bccf50501b4bf Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 23 Dec 2018 05:26:08 +0000 Subject: [PATCH 02/12] more tests for dist kvstore --- tests/nightly/dist_async_kvstore.py | 18 +++++++++++------- tests/nightly/dist_sync_kvstore.py | 15 +++++++++------ 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/tests/nightly/dist_async_kvstore.py b/tests/nightly/dist_async_kvstore.py index 3e400eafa045..b990b6b3f13e 100644 --- a/tests/nightly/dist_async_kvstore.py +++ b/tests/nightly/dist_async_kvstore.py @@ -27,22 +27,26 @@ nworker = kv.num_workers def test_gluon_trainer_type(): - def check_trainer_kv_update(update_on_kv): + def check_trainer_kv_update(weight_stype, update_on_kv): params = mx.gluon.ParameterDict() - x = params.get('x', shape=(10,1), lr_mult=1.0) + x = params.get('x', shape=(10,1), lr_mult=1.0, stype=weight_stype) params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') try: - trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv, update_on_kvstore=update_on_kv) + trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, + kvstore=kv, update_on_kvstore=update_on_kv) trainer._init_kvstore() assert trainer._kv_initialized assert trainer._update_on_kvstore is True except ValueError: assert update_on_kv is False - check_trainer_kv_update(False) - check_trainer_kv_update(True) - check_trainer_kv_update(None) + check_trainer_kv_update('default', False) + check_trainer_kv_update('default', True) + check_trainer_kv_update('default', None) + check_trainer_kv_update('row_sparse', False) + check_trainer_kv_update('row_sparse', True) + check_trainer_kv_update('row_sparse', None) print('worker ' + str(my_rank) + ' passed test_gluon_trainer_type') if __name__ == "__main__": - test_gluon_trainer_type() \ No newline at end of file + test_gluon_trainer_type() diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py index 861b85913ac8..c94dace3e31e 100644 --- a/tests/nightly/dist_sync_kvstore.py +++ b/tests/nightly/dist_sync_kvstore.py @@ -376,18 +376,21 @@ def check_invalid_pull(): check_invalid_pull() def test_gluon_trainer_type(): - def check_trainer_kv_type(stype, grad_stype, update_on_kv): + def check_trainer_kv_type(stype, grad_stype, update_on_kv, expected_update_on_kv): params = mx.gluon.ParameterDict() x = params.get('x', shape=(10,1), lr_mult=1.0, stype=stype, grad_stype=grad_stype) params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') - trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv) + trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, + 
kvstore=kv, update_on_kvstore=update_on_kv) trainer._init_kvstore() assert trainer._kv_initialized - assert trainer._update_on_kvstore is update_on_kv + assert trainer._update_on_kvstore is expected_update_on_kv - check_trainer_kv_type('default', 'default', False) - check_trainer_kv_type('default', 'row_sparse', True) - check_trainer_kv_type('row_sparse', 'row_sparse', True) + check_trainer_kv_type('default', 'default', None, True) + check_trainer_kv_type('default', 'default', True, True) + check_trainer_kv_type('default', 'default', False, False) + check_trainer_kv_type('default', 'row_sparse', None, True) + check_trainer_kv_type('row_sparse', 'row_sparse', None, True) print('worker ' + str(my_rank) + ' passed test_gluon_trainer_type') def test_gluon_trainer_step(): From b97e8719efe46509fc22d896c6cfc4105ac2a23a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 23 Dec 2018 05:41:45 +0000 Subject: [PATCH 03/12] more unittests --- python/mxnet/gluon/trainer.py | 3 ++- tests/python/unittest/test_gluon_trainer.py | 28 +++++++++++++++------ 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index e3e52db1ce3e..a712700b1f8c 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -135,7 +135,8 @@ def _init_optimizer(self, optimizer, optimizer_params): # create a deep copy of the optimizer per context for _ in range(len(self._contexts) - 1): optim = copy.deepcopy(optimizer) - # param_dict must not be deep copied + # param_dict must not be deep copied, so that if user mutate the lr_mult + # or wd_mult of some parameters, it takes effect. optim.param_dict = param_dict self._optimizers.append(optim) self._updaters.append(opt.get_updater(optim)) diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py index 7f636480c579..7fb1f1099c0b 100644 --- a/tests/python/unittest/test_gluon_trainer.py +++ b/tests/python/unittest/test_gluon_trainer.py @@ -55,16 +55,17 @@ def dict_equ(a, b): y.backward() trainer.step(1) + assert len(trainer._optimizers) == 2 + assert len(trainer._updaters) == 2 + assert trainer._optimizers[0].param_dict == trainer._optimizers[1].param_dict assert (x.data(mx.cpu(1)).asnumpy() == -2).all() x.lr_mult = 0.5 - with mx.autograd.record(): for w in x.list_data(): y = w + 1 y.backward() trainer.step(1) - assert (x.data(mx.cpu(1)).asnumpy() == -4).all() trainer.save_states('test_trainer.states') @@ -74,6 +75,8 @@ def dict_equ(a, b): if trainer._update_on_kvstore: dict_equ(trainer._kvstore._updater.states, states) assert trainer._optimizers[0] == trainer._kvstore._updater.optimizer + assert len(trainer._optimizers) == 2 + assert len(trainer._updaters) == 2 # invalid usage of update and allreduce_grads if update_on_kvstore assert_raises(AssertionError, trainer.update, 1) assert_raises(AssertionError, trainer.allreduce_grads) @@ -81,6 +84,8 @@ def dict_equ(a, b): for updater in trainer._updaters: dict_equ(updater.states, states) assert trainer._optimizers[0] == trainer._updaters[0].optimizer + assert len(trainer._optimizers) == 2 + assert len(trainer._updaters) == 2 x = gluon.Parameter('x', shape=(10,)) x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') @@ -212,11 +217,12 @@ def check_trainer_reset_kv(kv): @with_seed() def test_trainer_sparse_kv(): - def check_trainer_sparse_kv(kv, stype, grad_stype, update_on_kv): + def check_trainer_sparse_kv(kv, stype, grad_stype, update_on_kv, expected_update_on_kv): params = gluon.ParameterDict() x = 
params.get('x', shape=(10,1), lr_mult=1.0, stype=stype, grad_stype=grad_stype) params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') - trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv) + trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, + kvstore=kv, update_on_kvstore=update_on_kv) all_rows = mx.nd.arange(0, 10, ctx=mx.cpu(0)) ws = x.list_data() if stype == 'default' else x.list_row_sparse_data(all_rows) with mx.autograd.record(): @@ -226,7 +232,7 @@ def check_trainer_sparse_kv(kv, stype, grad_stype, update_on_kv): trainer.step(1) assert trainer._kvstore.type == kv assert trainer._kv_initialized - assert trainer._update_on_kvstore is update_on_kv + assert trainer._update_on_kvstore is expected_update_on_kv # the updated parameter should be based on the loaded checkpoint mx.nd.waitall() updated_w = x.data(mx.cpu(0)) if stype == 'default' else x.row_sparse_data(all_rows) @@ -234,6 +240,12 @@ def check_trainer_sparse_kv(kv, stype, grad_stype, update_on_kv): kvs = ['local', 'device'] for kv in kvs: - check_trainer_sparse_kv(kv, 'default', 'default', True) - check_trainer_sparse_kv(kv, 'default', 'row_sparse', False) - check_trainer_sparse_kv(kv, 'row_sparse', 'row_sparse', True) + check_trainer_sparse_kv(kv, 'default', 'default', True, True) + check_trainer_sparse_kv(kv, 'default', 'default', False, False) + check_trainer_sparse_kv(kv, 'default', 'default', None, True) + check_trainer_sparse_kv(kv, 'default', 'row_sparse', None, False) + check_trainer_sparse_kv(kv, 'row_sparse', 'row_sparse', None, True) + +@with_seed() +def test_trainer_lr_scheduler(): + pass From 9be416251925a9ec0b41fee91ff4267951544f95 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 23 Dec 2018 05:54:09 +0000 Subject: [PATCH 04/12] fix a bug --- python/mxnet/gluon/trainer.py | 8 +++-- tests/python/unittest/test_gluon_trainer.py | 34 ++++++++++++--------- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index a712700b1f8c..00ef28fb3c7e 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -217,9 +217,11 @@ def _init_kvstore(self): self._distributed = 'dist' in kvstore.type if kvstore else False update_on_kvstore = self._distributed # raise err if user provides unsupported configs - if config['update_on_kvstore'] is False and self._distributed: - raise RuntimeError("Cannot set update_on_kvstore=False on dist kvstore " - "when sparse gradients are present.") + if config['update_on_kvstore'] is not None: + if config['update_on_kvstore'] is False and self._distributed: + raise ValueError("Cannot set update_on_kvstore=False on dist kvstore " + "when sparse gradients are present.") + update_on_kvstore = config['update_on_kvstore'] else: # Training with dense weight and dense gradients. 
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py index 7fb1f1099c0b..97ca8dcd3f17 100644 --- a/tests/python/unittest/test_gluon_trainer.py +++ b/tests/python/unittest/test_gluon_trainer.py @@ -217,26 +217,29 @@ def check_trainer_reset_kv(kv): @with_seed() def test_trainer_sparse_kv(): - def check_trainer_sparse_kv(kv, stype, grad_stype, update_on_kv, expected_update_on_kv): + def check_trainer_sparse_kv(kv, stype, grad_stype, update_on_kv, expected): params = gluon.ParameterDict() x = params.get('x', shape=(10,1), lr_mult=1.0, stype=stype, grad_stype=grad_stype) params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv, update_on_kvstore=update_on_kv) all_rows = mx.nd.arange(0, 10, ctx=mx.cpu(0)) - ws = x.list_data() if stype == 'default' else x.list_row_sparse_data(all_rows) - with mx.autograd.record(): - for w in ws: - y = w + 1 - y.backward() - trainer.step(1) - assert trainer._kvstore.type == kv - assert trainer._kv_initialized - assert trainer._update_on_kvstore is expected_update_on_kv - # the updated parameter should be based on the loaded checkpoint - mx.nd.waitall() - updated_w = x.data(mx.cpu(0)) if stype == 'default' else x.row_sparse_data(all_rows) - assert (updated_w == -0.2).asnumpy().all() + try: + ws = x.list_data() if stype == 'default' else x.list_row_sparse_data(all_rows) + with mx.autograd.record(): + for w in ws: + y = w + 1 + y.backward() + trainer.step(1) + assert trainer._kvstore.type == kv + assert trainer._kv_initialized + assert trainer._update_on_kvstore is expected + # the updated parameter should be based on the loaded checkpoint + mx.nd.waitall() + updated_w = x.data(mx.cpu(0)) if stype == 'default' else x.row_sparse_data(all_rows) + assert (updated_w == -0.2).asnumpy().all() + except Exception as err: + assert isinstance(err, expected) kvs = ['local', 'device'] for kv in kvs: @@ -244,7 +247,10 @@ def check_trainer_sparse_kv(kv, stype, grad_stype, update_on_kv, expected_update check_trainer_sparse_kv(kv, 'default', 'default', False, False) check_trainer_sparse_kv(kv, 'default', 'default', None, True) check_trainer_sparse_kv(kv, 'default', 'row_sparse', None, False) + check_trainer_sparse_kv(kv, 'default', 'row_sparse', True, True) + check_trainer_sparse_kv(kv, 'default', 'row_sparse', False, False) check_trainer_sparse_kv(kv, 'row_sparse', 'row_sparse', None, True) + check_trainer_sparse_kv(kv, 'row_sparse', 'row_sparse', False, ValueError) @with_seed() def test_trainer_lr_scheduler(): From de6bd0c8a3b08ca41c9e208de27ca60d5895a696 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 23 Dec 2018 06:00:52 +0000 Subject: [PATCH 05/12] more dist exception test --- tests/nightly/dist_sync_kvstore.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py index c94dace3e31e..4523a361cf88 100644 --- a/tests/nightly/dist_sync_kvstore.py +++ b/tests/nightly/dist_sync_kvstore.py @@ -376,21 +376,26 @@ def check_invalid_pull(): check_invalid_pull() def test_gluon_trainer_type(): - def check_trainer_kv_type(stype, grad_stype, update_on_kv, expected_update_on_kv): + def check_trainer_kv_type(stype, grad_stype, update_on_kv, expected): params = mx.gluon.ParameterDict() x = params.get('x', shape=(10,1), lr_mult=1.0, stype=stype, grad_stype=grad_stype) params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') trainer = mx.gluon.Trainer(params, 'sgd', 
{'learning_rate': 0.1}, kvstore=kv, update_on_kvstore=update_on_kv) - trainer._init_kvstore() - assert trainer._kv_initialized - assert trainer._update_on_kvstore is expected_update_on_kv + try: + trainer._init_kvstore() + assert trainer._kv_initialized + assert trainer._update_on_kvstore is expected + except Exception as err: + assert isinstance(err, expected) check_trainer_kv_type('default', 'default', None, True) check_trainer_kv_type('default', 'default', True, True) check_trainer_kv_type('default', 'default', False, False) check_trainer_kv_type('default', 'row_sparse', None, True) + check_trainer_kv_type('default', 'row_sparse', False, ValueError) check_trainer_kv_type('row_sparse', 'row_sparse', None, True) + check_trainer_kv_type('row_sparse', 'row_sparse', False, ValueError) print('worker ' + str(my_rank) + ' passed test_gluon_trainer_type') def test_gluon_trainer_step(): From 634bca2698d9db6a1079ab93800c75cf40493cc4 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 23 Dec 2018 06:45:11 +0000 Subject: [PATCH 06/12] revert optimizer list --- python/mxnet/gluon/trainer.py | 78 +++++++++------------ tests/python/unittest/test_gluon_trainer.py | 29 +++++--- 2 files changed, 51 insertions(+), 56 deletions(-) diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index 00ef28fb3c7e..45f8dd546f3d 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -59,6 +59,7 @@ class Trainer(object): `update_on_kvstore=False` is not supported in the following cases: - dist kvstore with sparse weights or sparse gradients - dist async kvstore + - optimizer.lr_scheduler is not None Properties ---------- @@ -94,13 +95,7 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device', self._compression_params = compression_params optimizer_params = optimizer_params if optimizer_params else {} self._scale = float(optimizer_params.get('rescale_grad', 1.0)) - # one optimizer / updater per context - # If self._update_on_kvstore is set to `True` in `_init_kvstore()`, then: - # - updaters[:] are never used. - # - optimizer[0] is registered with kvstore. optimizer[1:] are never used. self._contexts = self._check_contexts() - self._optimizers = [] - self._updaters = [] self._init_optimizer(optimizer, optimizer_params) self._kvstore_params = {'kvstore': kvstore, 'update_on_kvstore': update_on_kvstore} self._kv_initialized = False @@ -127,19 +122,15 @@ def _init_optimizer(self, optimizer, optimizer_params): assert not optimizer_params, \ "optimizer_params must be None if optimizer is an instance of " \ "Optimizer instead of str" - else: - optimizer = opt.create(optimizer, **optimizer_params) - optimizer.param_dict = param_dict - self._optimizers = [optimizer] - self._updaters = [opt.get_updater(optimizer)] - # create a deep copy of the optimizer per context - for _ in range(len(self._contexts) - 1): - optim = copy.deepcopy(optimizer) + self._optimizer = optimizer # param_dict must not be deep copied, so that if user mutate the lr_mult # or wd_mult of some parameters, it takes effect. - optim.param_dict = param_dict - self._optimizers.append(optim) - self._updaters.append(opt.get_updater(optim)) + self._optimizer.param_dict = param_dict + else: + self._optimizer = opt.create(optimizer, param_dict=param_dict, + **optimizer_params) + self._updaters = [opt.get_updater(self._optimizer) \ + for _ in self._contexts] def _init_params(self): """Initialize parameters in the KVStore. 
@@ -245,9 +236,13 @@ def _init_kvstore(self): kvstore.set_gradient_compression(self._compression_params) if update_on_kvstore: # optimizer preferably needs to be set before init for multiprecision - kvstore.set_optimizer(self._optimizers[0]) + kvstore.set_optimizer(self._optimizer) self._kvstore = kvstore self._update_on_kvstore = update_on_kvstore + if self._optimizer.lr_scheduler is not None: + assert self._update_on_kvstore, "update_on_kvstore=True does not support " \ + "optimizer with LRScheduler. Please " \ + "consider setting learning rate manually." else: self._kvstore = None self._update_on_kvstore = None @@ -256,16 +251,11 @@ def _init_kvstore(self): @property def learning_rate(self): - if not isinstance(self._optimizers[0], opt.Optimizer): + if not isinstance(self._optimizer, opt.Optimizer): raise UserWarning("Optimizer has to be defined before its learning " "rate can be accessed.") else: - lr = self._optimizers[0].learning_rate - for i in range(self._contexts): - if self._optimizers[i].learning_rate != lr: - raise UserWarning("The optimizer on %s has a different learning rate" - " from that on %s. Cannot return learning rate") - return lr + return self._optimizer.learning_rate def set_learning_rate(self, lr): """Sets a new learning rate of the optimizer. @@ -275,14 +265,11 @@ def set_learning_rate(self, lr): lr : float The new learning rate of the optimizer. """ - if not self._optimizers: + if not isinstance(self._optimizer, opt.Optimizer): raise UserWarning("Optimizer has to be defined before its learning " "rate is mutated.") - for optim in self._optimizers: - if not isinstance(optim, opt.Optimizer): - raise UserWarning("Optimizer has to be defined before its learning " - "rate is mutated.") - optim.set_learning_rate(lr) + else: + self._optimizer.set_learning_rate(lr) def _row_sparse_pull(self, parameter, out, row_id, full_idx=False): """Internal method to invoke pull operations on KVStore. If `full_idx` is set to True, @@ -301,15 +288,14 @@ def _row_sparse_pull(self, parameter, out, row_id, full_idx=False): self._kvstore.row_sparse_pull(idx, out=out, row_ids=row_id, priority=-idx) def _check_and_rescale_grad(self, scale): - for optim in self._optimizers: - if self._update_on_kvstore and self._distributed and self._kv_initialized: - if optim.rescale_grad != scale: - raise UserWarning('Possible change in the `batch_size` from previous ' - '`step` detected. Optimizer gradient normalizing ' - 'factor will not change w.r.t new batch_size when ' - 'update_on_kvstore=True and when distributed kvstore ' - 'is used.') - optim.rescale_grad = scale + if self._update_on_kvstore and self._distributed and self._kv_initialized: + if self._optimizer.rescale_grad != scale: + raise UserWarning('Possible change in the `batch_size` from previous ' + '`step` detected. Optimizer gradient normalizing ' + 'factor will not change w.r.t new batch_size when ' + 'update_on_kvstore=True and when distributed kvstore ' + 'is used.') + self._optimizer.rescale_grad = scale def step(self, batch_size, ignore_stale_grad=False): """Makes one step of parameter update. Should be called after @@ -448,7 +434,7 @@ def save_states(self, fname): `optimizer.param_dict`, which contains Parameter information (such as `lr_mult` and `wd_mult`) will not be saved. 
""" - assert self._optimizers and self._optimizers[0] is not None + assert self._optimizer is not None if not self._kv_initialized: self._init_kvstore() @@ -483,14 +469,14 @@ def load_states(self, fname): self._init_params() if self._update_on_kvstore: - self._kvstore.load_optimizer_states(fname) - optimizer = self._kvstore._updater.optimizer - self._init_optimizer(optimizer, None) + self._optimizer = self._kvstore._updater.optimizer + param_dict = {i: param for i, param in enumerate(self._params)} else: with open(fname, 'rb') as f: states = f.read() param_dict = {i: param for i, param in enumerate(self._params)} for updater in self._updaters: updater.set_states(states) - updater.optimizer.param_dict = param_dict - self._optimizers = [updater.optimizer for updater in self._updaters] + updater.optimizer = self._updaters[0].optimizer + self._optimizer = self._updaters[0].optimizer + self._optimizer.param_dict = param_dict diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py index 97ca8dcd3f17..a49075ebf207 100644 --- a/tests/python/unittest/test_gluon_trainer.py +++ b/tests/python/unittest/test_gluon_trainer.py @@ -55,9 +55,7 @@ def dict_equ(a, b): y.backward() trainer.step(1) - assert len(trainer._optimizers) == 2 - assert len(trainer._updaters) == 2 - assert trainer._optimizers[0].param_dict == trainer._optimizers[1].param_dict + assert trainer._optimizer.param_dict == trainer._optimizer.param_dict assert (x.data(mx.cpu(1)).asnumpy() == -2).all() x.lr_mult = 0.5 @@ -74,18 +72,14 @@ def dict_equ(a, b): trainer.load_states('test_trainer.states') if trainer._update_on_kvstore: dict_equ(trainer._kvstore._updater.states, states) - assert trainer._optimizers[0] == trainer._kvstore._updater.optimizer - assert len(trainer._optimizers) == 2 - assert len(trainer._updaters) == 2 + assert trainer._optimizer == trainer._kvstore._updater.optimizer # invalid usage of update and allreduce_grads if update_on_kvstore assert_raises(AssertionError, trainer.update, 1) assert_raises(AssertionError, trainer.allreduce_grads) else: for updater in trainer._updaters: dict_equ(updater.states, states) - assert trainer._optimizers[0] == trainer._updaters[0].optimizer - assert len(trainer._optimizers) == 2 - assert len(trainer._updaters) == 2 + assert trainer._optimizer == trainer._updaters[0].optimizer x = gluon.Parameter('x', shape=(10,)) x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') @@ -254,4 +248,19 @@ def check_trainer_sparse_kv(kv, stype, grad_stype, update_on_kv, expected): @with_seed() def test_trainer_lr_scheduler(): - pass + x = gluon.Parameter('x', shape=(10,)) + x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') + freq = 2 + factor = 0.1 + lr = 1 + lr_sched = mx.lr_scheduler.FactorScheduler(freq, factor=factor, base_lr=lr) + trainer = gluon.Trainer([x], 'sgd', {'learning_rate': lr, 'lr_scheduler': lr_sched}) + for i in range(10): + with mx.autograd.record(): + for w in x.list_data(): + y = w + 1 + y.backward() + trainer.step(1) + if i % freq == 0: + assert trainer.learning_rate == lr, (lr, trainer.learning_rate, i) + lr *= factor From b79e4d386f55dbc7c7a693f5139ac6d6763aaccc Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 23 Dec 2018 06:49:37 +0000 Subject: [PATCH 07/12] fix bug and comment --- python/mxnet/gluon/trainer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index 45f8dd546f3d..07bc82fe5932 100644 --- a/python/mxnet/gluon/trainer.py +++ 
b/python/mxnet/gluon/trainer.py @@ -20,7 +20,6 @@ """Parameter optimizer.""" __all__ = ['Trainer'] -import copy from .. import optimizer as opt from ..model import _create_kvstore, _create_sparse_kvstore from .parameter import ParameterDict, Parameter @@ -240,7 +239,7 @@ def _init_kvstore(self): self._kvstore = kvstore self._update_on_kvstore = update_on_kvstore if self._optimizer.lr_scheduler is not None: - assert self._update_on_kvstore, "update_on_kvstore=True does not support " \ + assert self._update_on_kvstore, "update_on_kvstore=False does not support " \ "optimizer with LRScheduler. Please " \ "consider setting learning rate manually." else: @@ -469,14 +468,14 @@ def load_states(self, fname): self._init_params() if self._update_on_kvstore: + self._kvstore.load_optimizer_states(fname) self._optimizer = self._kvstore._updater.optimizer - param_dict = {i: param for i, param in enumerate(self._params)} else: with open(fname, 'rb') as f: states = f.read() - param_dict = {i: param for i, param in enumerate(self._params)} for updater in self._updaters: updater.set_states(states) updater.optimizer = self._updaters[0].optimizer self._optimizer = self._updaters[0].optimizer + param_dict = {i: param for i, param in enumerate(self._params)} self._optimizer.param_dict = param_dict From 8f84f216d31fd71f923d1b6e62bef779e218ad21 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 23 Dec 2018 08:08:14 +0000 Subject: [PATCH 08/12] fix doc rendering and lint --- python/mxnet/gluon/trainer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index 07bc82fe5932..7328f6115da1 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -53,8 +53,8 @@ class Trainer(object): Whether to perform parameter updates on kvstore. If None, then trainer will choose the more suitable option depending on the type of kvstore. - Notes - ----- + Note + ---- `update_on_kvstore=False` is not supported in the following cases: - dist kvstore with sparse weights or sparse gradients - dist async kvstore @@ -288,12 +288,12 @@ def _row_sparse_pull(self, parameter, out, row_id, full_idx=False): def _check_and_rescale_grad(self, scale): if self._update_on_kvstore and self._distributed and self._kv_initialized: - if self._optimizer.rescale_grad != scale: - raise UserWarning('Possible change in the `batch_size` from previous ' - '`step` detected. Optimizer gradient normalizing ' - 'factor will not change w.r.t new batch_size when ' - 'update_on_kvstore=True and when distributed kvstore ' - 'is used.') + if self._optimizer.rescale_grad != scale: + raise UserWarning('Possible change in the `batch_size` from previous ' + '`step` detected. 
Optimizer gradient normalizing ' + 'factor will not change w.r.t new batch_size when ' + 'update_on_kvstore=True and when distributed kvstore ' + 'is used.') self._optimizer.rescale_grad = scale def step(self, batch_size, ignore_stale_grad=False): From 9493110890e5d92e2f50f93b352b8b320fd3dc1f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 23 Dec 2018 08:40:38 +0000 Subject: [PATCH 09/12] add invalid sched test --- python/mxnet/gluon/trainer.py | 8 ++++---- tests/python/unittest/test_gluon_trainer.py | 20 +++++++++++++++++++- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index 7328f6115da1..203de2e8df3b 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -238,10 +238,10 @@ def _init_kvstore(self): kvstore.set_optimizer(self._optimizer) self._kvstore = kvstore self._update_on_kvstore = update_on_kvstore - if self._optimizer.lr_scheduler is not None: - assert self._update_on_kvstore, "update_on_kvstore=False does not support " \ - "optimizer with LRScheduler. Please " \ - "consider setting learning rate manually." + if self._optimizer.lr_scheduler and not self._update_on_kvstore: + raise ValueError("update_on_kvstore=False does not support " \ + "optimizer with LRScheduler. Please " \ + "consider setting learning rate manually.") else: self._kvstore = None self._update_on_kvstore = None diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py index a49075ebf207..985c38c31356 100644 --- a/tests/python/unittest/test_gluon_trainer.py +++ b/tests/python/unittest/test_gluon_trainer.py @@ -247,7 +247,7 @@ def check_trainer_sparse_kv(kv, stype, grad_stype, update_on_kv, expected): check_trainer_sparse_kv(kv, 'row_sparse', 'row_sparse', False, ValueError) @with_seed() -def test_trainer_lr_scheduler(): +def test_trainer_lr_sched(): x = gluon.Parameter('x', shape=(10,)) x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') freq = 2 @@ -264,3 +264,21 @@ def test_trainer_lr_scheduler(): if i % freq == 0: assert trainer.learning_rate == lr, (lr, trainer.learning_rate, i) lr *= factor + mx.nd.waitall() + +@with_seed() +def test_trainer_invalid_lr_sched(): + x = gluon.Parameter('x', shape=(10,)) + x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') + freq = 2 + factor = 0.1 + lr = 1 + lr_sched = mx.lr_scheduler.FactorScheduler(freq, factor=factor, base_lr=lr) + invalid_trainer = gluon.Trainer([x], 'sgd', {'learning_rate': lr, 'lr_scheduler': lr_sched}, + update_on_kvstore=False) + with mx.autograd.record(): + for w in x.list_data(): + y = w + 1 + y.backward() + assert_raises(ValueError, invalid_trainer.step, 1) + mx.nd.waitall() From f2aec94a5964413f088776b399bfc0b6449d07fa Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Sun, 23 Dec 2018 09:10:15 -0800 Subject: [PATCH 10/12] fix website --- python/mxnet/gluon/trainer.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index 203de2e8df3b..b01549fb7b6b 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -28,6 +28,13 @@ class Trainer(object): """Applies an `Optimizer` on a set of Parameters. Trainer should be used together with `autograd`. + .. 
note:: + + `update_on_kvstore=False` is not supported in the following cases: + - dist kvstore with sparse weights or sparse gradients + - dist async kvstore + - optimizer.lr_scheduler is not None + Parameters ---------- params : ParameterDict @@ -53,13 +60,6 @@ class Trainer(object): Whether to perform parameter updates on kvstore. If None, then trainer will choose the more suitable option depending on the type of kvstore. - Note - ---- - `update_on_kvstore=False` is not supported in the following cases: - - dist kvstore with sparse weights or sparse gradients - - dist async kvstore - - optimizer.lr_scheduler is not None - Properties ---------- learning_rate : float From 3a44452211e97581723644ac382da7e2c103682e Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Sun, 23 Dec 2018 10:18:18 -0800 Subject: [PATCH 11/12] trigger --- python/mxnet/gluon/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index b01549fb7b6b..cb7a87e502cd 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -33,7 +33,7 @@ class Trainer(object): `update_on_kvstore=False` is not supported in the following cases: - dist kvstore with sparse weights or sparse gradients - dist async kvstore - - optimizer.lr_scheduler is not None + - `optimizer.lr_scheduler` is not None Parameters ---------- From b77c894165142e1c11c379d04418d59edf19915e Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Thu, 27 Dec 2018 15:13:57 -0800 Subject: [PATCH 12/12] update doc --- python/mxnet/gluon/trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index cb7a87e502cd..f6c0a31b52e2 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -30,7 +30,9 @@ class Trainer(object): .. note:: - `update_on_kvstore=False` is not supported in the following cases: + For the following cases, updates will always happen on kvstore, + i.e., you cannot set update_on_kvstore=False. + - dist kvstore with sparse weights or sparse gradients - dist async kvstore - `optimizer.lr_scheduler` is not None
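
For reference, a short usage sketch of the behaviour these patches document. This is hypothetical user code, not part of the series: it assumes MXNet with these changes applied, a single CPU context, the built-in `sgd` optimizer, and an arbitrary small `Dense` block; the shapes and schedule values are illustrative only.

    import mxnet as mx
    from mxnet import autograd, gluon

    net = gluon.nn.Dense(2)
    net.initialize(ctx=mx.cpu(0))

    # An optimizer with an LRScheduler: per this series, combining it with
    # update_on_kvstore=False raises ValueError, so leave update_on_kvstore
    # at its default here.
    sched = mx.lr_scheduler.FactorScheduler(step=2, factor=0.5, base_lr=0.1)
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': 0.1, 'lr_scheduler': sched})

    for _ in range(4):
        data = mx.nd.ones((4, 8))
        with autograd.record():
            loss = net(data).sum()
        loss.backward()
        # step() rescales gradients by 1/batch_size; with a distributed kvstore
        # and update_on_kvstore=True the batch size must stay constant across steps.
        trainer.step(batch_size=4)

    # optimizer.param_dict is shared with the Trainer's parameters (never deep
    # copied), so mutating lr_mult/wd_mult on a Parameter takes effect on later
    # updates without recreating the Trainer.
    for param in net.collect_params().values():
        param.lr_mult = 0.5

The sketch only exercises the single-process path; with a `dist` kvstore the same code is subject to the constraints listed above (sparse weights/gradients and async mode force updates onto the kvstore).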