From 38671abc9d4ffb32b6439de36f8b1dd887a07d54 Mon Sep 17 00:00:00 2001
From: Przemek Tredak
Date: Sat, 9 Mar 2019 08:18:57 -0800
Subject: [PATCH 1/4] LRScheduler with update_on_kvstore=False

---
 python/mxnet/gluon/trainer.py               |  8 ++++----
 python/mxnet/optimizer/optimizer.py         | 10 +++++++++-
 tests/python/unittest/test_gluon_trainer.py | 21 ++++++++++++---------
 3 files changed, 25 insertions(+), 14 deletions(-)

diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 8060f38ac2aa..0a153384fd4a 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -241,10 +241,10 @@ def _init_kvstore(self):
             kvstore.set_optimizer(self._optimizer)
             self._kvstore = kvstore
             self._update_on_kvstore = update_on_kvstore
-            if self._optimizer.lr_scheduler and not self._update_on_kvstore:
-                raise ValueError("update_on_kvstore=False does not support " \
-                                 "optimizer with LRScheduler. Please " \
-                                 "consider setting learning rate manually.")
+            #if self._optimizer.lr_scheduler and not self._update_on_kvstore:
+            #    raise ValueError("update_on_kvstore=False does not support " \
+            #                     "optimizer with LRScheduler. Please " \
+            #                     "consider setting learning rate manually.")
         else:
             self._kvstore = None
             self._update_on_kvstore = None
diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py
index def2c958ede4..aab04a2bfb62 100644
--- a/python/mxnet/optimizer/optimizer.py
+++ b/python/mxnet/optimizer/optimizer.py
@@ -106,7 +106,8 @@ def __init__(self, rescale_grad=1., param_idx2name=None, wd=0.,
         self.wd_mult = {}
         self.begin_num_update = begin_num_update
         self.num_update = begin_num_update
-        self._index_update_count = {}
+        self._all_index_update_counts = {0 : {}}
+        self._index_update_count = self._all_index_update_counts[0]
         self.clip_gradient = clip_gradient
         self.multi_precision = multi_precision
         self.aggregate_num = 0
@@ -380,6 +381,11 @@ def set_wd_mult(self, args_wd_mult):
                     self.wd_mult[name] = float(attr[name]['__wd_mult__'])
         self.wd_mult.update(args_wd_mult)
 
+    def set_current_context(self, ctx):
+        if ctx not in self._all_index_update_counts:
+            self._all_index_update_counts[ctx] = {}
+        self._index_update_count = self._all_index_update_counts[ctx]
+
     def _update_count(self, index):
         """Updates num_update.
@@ -1623,6 +1629,8 @@ def __call__(self, index, grad, weight):
             indices = index
             grads = grad
             weights = weight
+        if weights:
+            self.optimizer.set_current_context(weights[0].context.device_id)
         for i, idx in enumerate(indices):
             # convert ctypes.char_p.value back to python str if needed
             if isinstance(idx, bytes):
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
index 9f190a0a88c2..2d5874a8b97b 100644
--- a/tests/python/unittest/test_gluon_trainer.py
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -272,19 +272,22 @@ def test_trainer_lr_sched():
             lr *= factor
     mx.nd.waitall()
 
-@with_seed()
-def test_trainer_invalid_lr_sched():
+    # Update on kvstore = False
     x = gluon.Parameter('x', shape=(10,))
     x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
     freq = 2
     factor = 0.1
     lr = 1
     lr_sched = mx.lr_scheduler.FactorScheduler(freq, factor=factor, base_lr=lr)
-    invalid_trainer = gluon.Trainer([x], 'sgd', {'learning_rate': lr, 'lr_scheduler': lr_sched},
-                                    update_on_kvstore=False)
-    with mx.autograd.record():
-        for w in x.list_data():
-            y = w + 1
-            y.backward()
-    assert_raises(ValueError, invalid_trainer.step, 1)
+    trainer = gluon.Trainer([x], 'sgd', {'learning_rate': lr, 'lr_scheduler': lr_sched},
+                            update_on_kvstore=False)
+    for i in range(10):
+        with mx.autograd.record():
+            for w in x.list_data():
+                y = w + 1
+                y.backward()
+        trainer.step(1)
+        if i % freq == 0:
+            assert trainer.learning_rate == lr, (lr, trainer.learning_rate, i)
+            lr *= factor
     mx.nd.waitall()

From 86b8140cacdcf6c613c17436249451ebad4de9e7 Mon Sep 17 00:00:00 2001
From: Przemyslaw Tredak
Date: Sun, 10 Mar 2019 09:58:17 -0700
Subject: [PATCH 2/4] Cleaning trainer.py

---
 python/mxnet/gluon/trainer.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 0a153384fd4a..45a44d8eb3e4 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -241,10 +241,6 @@ def _init_kvstore(self):
             kvstore.set_optimizer(self._optimizer)
             self._kvstore = kvstore
             self._update_on_kvstore = update_on_kvstore
-            #if self._optimizer.lr_scheduler and not self._update_on_kvstore:
-            #    raise ValueError("update_on_kvstore=False does not support " \
-            #                     "optimizer with LRScheduler. Please " \
-            #                     "consider setting learning rate manually.")
         else:
             self._kvstore = None
             self._update_on_kvstore = None

From 55e0560faf13a221f2adba4dcb90581357fef856 Mon Sep 17 00:00:00 2001
From: Przemek Tredak
Date: Mon, 11 Mar 2019 14:58:47 -0700
Subject: [PATCH 3/4] Retrigger CI

From a6b6ef621d743781876f6e47671273081ca39008 Mon Sep 17 00:00:00 2001
From: Przemek Tredak
Date: Fri, 15 Mar 2019 11:25:41 -0700
Subject: [PATCH 4/4] Fixes from review

---
 python/mxnet/optimizer/optimizer.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py
index aab04a2bfb62..2e7fe86c5af9 100644
--- a/python/mxnet/optimizer/optimizer.py
+++ b/python/mxnet/optimizer/optimizer.py
@@ -381,10 +381,17 @@ def set_wd_mult(self, args_wd_mult):
                     self.wd_mult[name] = float(attr[name]['__wd_mult__'])
         self.wd_mult.update(args_wd_mult)
 
-    def set_current_context(self, ctx):
-        if ctx not in self._all_index_update_counts:
-            self._all_index_update_counts[ctx] = {}
-        self._index_update_count = self._all_index_update_counts[ctx]
+    def _set_current_context(self, device_id):
+        """Sets the number of the currently handled device.
+
+        Parameters
+        ----------
+        device_id : int
+            The number of the current device.
+        """
+        if device_id not in self._all_index_update_counts:
+            self._all_index_update_counts[device_id] = {}
+        self._index_update_count = self._all_index_update_counts[device_id]
 
     def _update_count(self, index):
         """Updates num_update.
@@ -1630,7 +1637,7 @@ def __call__(self, index, grad, weight):
             grads = grad
             weights = weight
         if weights:
-            self.optimizer.set_current_context(weights[0].context.device_id)
+            self.optimizer._set_current_context(weights[0].context.device_id)
         for i, idx in enumerate(indices):
             # convert ctypes.char_p.value back to python str if needed
             if isinstance(idx, bytes):
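
Taken together, the series lets a Gluon Trainer created with update_on_kvstore=False
drive an LRScheduler: the optimizer keeps a separate _index_update_count per device
(PATCH 1/4), selected via _set_current_context(device_id) before each local update
(PATCH 4/4), so num_update advances once per logical step rather than once per device
copy. Below is a minimal usage sketch assembled from the test added in PATCH 1/4; the
two-CPU setup and FactorScheduler parameters come straight from that test, while the
print statement is illustrative and not part of the series.

import mxnet as mx
from mxnet import gluon

# Parameter replicated on two devices, as in the new test case.
x = gluon.Parameter('x', shape=(10,))
x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')

# Decay the learning rate by `factor` every `freq` updates.
freq = 2
factor = 0.1
lr_sched = mx.lr_scheduler.FactorScheduler(freq, factor=factor, base_lr=1)

# Before this series, this combination raised the ValueError removed in
# PATCH 1/4 and 2/4; now the scheduler is honored during local updates.
trainer = gluon.Trainer([x], 'sgd',
                        {'learning_rate': 1, 'lr_scheduler': lr_sched},
                        update_on_kvstore=False)

for i in range(10):
    with mx.autograd.record():
        for w in x.list_data():       # one forward/backward per device copy
            y = w + 1
            y.backward()
    # One logical update: the per-device update counts keep the two device
    # copies from advancing the schedule twice per step.
    trainer.step(1)
    print(i, trainer.learning_rate)   # drops by 10x every `freq` steps
mx.nd.waitall()

Keying the counts by weights[0].context.device_id assumes all weights in a given
updater call live on the same device, which holds here because the Trainer invokes
one updater per device.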