Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix learning rate scheduler being unexpectedly overwritten by optimizer's default value #16487

Merged
merged 7 commits on Oct 18, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions python/mxnet/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ class Optimizer(object):
clip_gradient : float, optional, default None
Clip the gradient by projecting onto the box ``[-clip_gradient, clip_gradient]``.

learning_rate : float, optional, default 0.01
The initial learning rate.
learning_rate : float, optional, default None
The initial learning rate. If None, the optimization will use the
learning rate from ``lr_scheduler``. If not None, it will overwrite
wkcn marked this conversation as resolved.
Show resolved Hide resolved
the learning rate in ``lr_scheduler``.

lr_scheduler : LRScheduler, optional, default None
The learning rate scheduler.
Expand Down Expand Up @@ -97,14 +99,19 @@ class Optimizer(object):
optimizer, its learning rate can be accessed as optimizer.learning_rate.
"""
def __init__(self, rescale_grad=1., param_idx2name=None, wd=0.,
clip_gradient=None, learning_rate=0.01,
clip_gradient=None, learning_rate=None,
lr_scheduler=None, sym=None, begin_num_update=0,
multi_precision=False, param_dict=None):
self.rescale_grad = rescale_grad
self.lr = learning_rate
self.lr_scheduler = lr_scheduler
if lr_scheduler is not None:
self.lr_scheduler.base_lr = learning_rate
if self.lr_scheduler is None and learning_rate is None:
learning_rate = 0.01
self.lr = learning_rate
if self.lr_scheduler is not None and learning_rate is not None:
if self.lr_scheduler.base_lr != learning_rate:
raise UserWarning("learning rate from ``lr_scheduler`` has been "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't this terminate the execution and start the exception handling? If so, the statement below wouldn't be executed.

"overwritten by ``learning_rate`` in optimizer.")
self.lr_scheduler.base_lr = learning_rate

self.wd = wd
self.lr_mult = {}
Expand Down
4 changes: 4 additions & 0 deletions tests/python/unittest/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def test_learning_rate():
o2.lr_scheduler.base_lr = 0.4
assert o2.learning_rate == 0.4

lr_s = lr_scheduler.FactorScheduler(base_lr=1024)
o3 = mx.optimizer.Optimizer(lr_scheduler=lr_s)
assert o3.learning_rate == 1024


@raises(UserWarning)
@with_seed()
Expand Down