[MXNET-535] Fix bugs in LR Schedulers and add warmup (apache#11234)
* Add warmup and fix inconsistencies with learning rate schedulers

* add comments

* remove assert
rahul003 authored and anirudh2290 committed Sep 19, 2018
1 parent 032f249 commit 0b58ceb
Showing 3 changed files with 183 additions and 19 deletions.
3 changes: 2 additions & 1 deletion example/image-classification/common/fit.py
@@ -49,7 +49,8 @@ def _get_lr_scheduler(args, kv):
        steps = [epoch_size * (x - begin_epoch)
                 for x in step_epochs if x - begin_epoch > 0]
        if steps:
-            return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor))
+            return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor,
+                                                             base_lr=args.lr))
        else:
            return (lr, None)
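For context: without the explicit base_lr, MultiFactorScheduler decays from its own default of 0.01 rather than from the --lr passed to the script, which is the inconsistency this change fixes. A minimal sketch of the difference, with illustrative values and the base_lr argument added by this commit:

    import mxnet as mx

    args_lr = 0.5                      # hypothetical --lr from the command line
    steps = [3000, 6000]

    # Without base_lr, decay starts from the scheduler's default of 0.01:
    sched_old = mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=0.1)
    # With base_lr forwarded, the scheduler starts from the requested learning rate:
    sched_new = mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=0.1, base_lr=args_lr)

    print(sched_old(3001), sched_new(3001))   # 0.001 vs. 0.05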

147 changes: 129 additions & 18 deletions python/mxnet/lr_scheduler.py
@@ -17,6 +17,7 @@

"""Scheduling learning rate."""
import logging
from math import cos, pi

class LRScheduler(object):
"""Base class of a learning rate scheduler.
@@ -28,9 +29,41 @@ class LRScheduler(object):
    ----------
    base_lr : float, optional
        The initial learning rate.
+    warmup_steps: int
+        number of warmup steps used before this scheduler starts decay
+    warmup_begin_lr: float
+        if using warmup, the learning rate from which it starts warming up
+    warmup_mode: string
+        warmup can be done in two modes.
+        'linear' mode gradually increases lr with each step in equal increments
+        'constant' mode keeps lr at warmup_begin_lr for warmup_steps
    """
-    def __init__(self, base_lr=0.01):
+    def __init__(self, base_lr=0.01,
+                 warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):
        self.base_lr = base_lr
+        assert isinstance(warmup_steps, int)
+        self.warmup_steps = warmup_steps
+
+        self.warmup_final_lr = base_lr
+        self.warmup_begin_lr = warmup_begin_lr
+        if self.warmup_begin_lr > self.warmup_final_lr:
+            raise ValueError("Base lr has to be higher than warmup_begin_lr")
+        if self.warmup_steps < 0:
+            raise ValueError("Warmup steps has to be positive or 0")
+        if warmup_mode not in ['linear', 'constant']:
+            raise ValueError("Supports only linear and constant modes of warmup")
+        self.warmup_mode = warmup_mode
+
+    def get_warmup_lr(self, num_update):
+        assert num_update < self.warmup_steps
+        if self.warmup_mode == 'linear':
+            increase = (self.warmup_final_lr - self.warmup_begin_lr) \
+                       * float(num_update) / float(self.warmup_steps)
+            return self.warmup_begin_lr + increase
+        elif self.warmup_mode == 'constant':
+            return self.warmup_begin_lr
+        else:
+            raise ValueError("Invalid warmup mode %s"%self.warmup_mode)

    def __call__(self, num_update):
        """Return a new learning rate.
@@ -66,8 +99,9 @@ class FactorScheduler(LRScheduler):
    stop_factor_lr : float, optional
        Stop updating the learning rate if it is less than this value.
    """
-    def __init__(self, step, factor=1, stop_factor_lr=1e-8):
-        super(FactorScheduler, self).__init__()
+    def __init__(self, step, factor=1, stop_factor_lr=1e-8, base_lr=0.01,
+                 warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):
+        super(FactorScheduler, self).__init__(base_lr, warmup_steps, warmup_begin_lr, warmup_mode)
        if step < 1:
            raise ValueError("Schedule step must be greater or equal than 1 round")
        if factor > 1.0:
@@ -78,6 +112,9 @@ def __init__(self, step, factor=1, stop_factor_lr=1e-8):
        self.count = 0

    def __call__(self, num_update):
+        if num_update < self.warmup_steps:
+            return self.get_warmup_lr(num_update)
+
        # NOTE: use while rather than if (for continuing training via load_epoch)
        while num_update > self.count + self.step:
            self.count += self.step
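FactorScheduler multiplies the learning rate by factor after every step updates once warmup is over, and never lets it fall below stop_factor_lr. A short sketch with illustrative values (note the scheduler is stateful, so it should be queried with non-decreasing update counts):

    import mxnet as mx

    sched = mx.lr_scheduler.FactorScheduler(step=100, factor=0.5, stop_factor_lr=0.05,
                                            base_lr=0.8, warmup_steps=10,
                                            warmup_begin_lr=0.1, warmup_mode='constant')
    print(sched(5))     # 0.1  -- constant warmup value
    print(sched(50))    # 0.8  -- base_lr once warmup is over
    print(sched(150))   # 0.4  -- one decay step applied
    print(sched(550))   # 0.05 -- floored at stop_factor_lr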
@@ -108,9 +145,19 @@ class MultiFactorScheduler(LRScheduler):
        The list of steps to schedule a change
    factor: float
        The factor to change the learning rate.
+    warmup_steps: int
+        number of warmup steps used before this scheduler starts decay
+    warmup_begin_lr: float
+        if using warmup, the learning rate from which it starts warming up
+    warmup_mode: string
+        warmup can be done in two modes.
+        'linear' mode gradually increases lr with each step in equal increments
+        'constant' mode keeps lr at warmup_begin_lr for warmup_steps
    """
-    def __init__(self, step, factor=1):
-        super(MultiFactorScheduler, self).__init__()
+    def __init__(self, step, factor=1, base_lr=0.01, warmup_steps=0, warmup_begin_lr=0,
+                 warmup_mode='linear'):
+        super(MultiFactorScheduler, self).__init__(base_lr, warmup_steps,
+                                                   warmup_begin_lr, warmup_mode)
        assert isinstance(step, list) and len(step) >= 1
        for i, _step in enumerate(step):
            if i != 0 and step[i] <= step[i-1]:
@@ -125,6 +172,9 @@ def __init__(self, step, factor=1):
        self.count = 0

    def __call__(self, num_update):
+        if num_update < self.warmup_steps:
+            return self.get_warmup_lr(num_update)
+
        # NOTE: use while rather than if (for continuing training via load_epoch)
        while self.cur_step_ind <= len(self.step)-1:
            if num_update > self.step[self.cur_step_ind]:
@@ -138,33 +188,94 @@ def __call__(self, num_update):
        return self.base_lr
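MultiFactorScheduler applies factor once after each boundary in step, after the optional warmup. A short sketch with illustrative values:

    import mxnet as mx

    sched = mx.lr_scheduler.MultiFactorScheduler(step=[30, 60], factor=0.5, base_lr=0.4,
                                                 warmup_steps=10, warmup_begin_lr=0.0,
                                                 warmup_mode='linear')
    print(sched(5))    # 0.2 -- halfway through the linear ramp from 0.0 to 0.4
    print(sched(20))   # 0.4 -- base_lr until the first boundary
    print(sched(31))   # 0.2 -- halved after update 30
    print(sched(61))   # 0.1 -- halved again after update 60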

class PolyScheduler(LRScheduler):
-    """ Reduce the learning rate by given a list of steps.
+    """ Reduce the learning rate according to a polynomial of given power.
-    Calculate the new learning rate by::
+    Calculate the new learning rate, after warmup if any, by::
-       base_lr * (1-nup/max_nup)^pwr
+       final_lr + (start_lr - final_lr) * (1-nup/max_nup)^pwr
       if nup < max_nup, 0 otherwise.
    Parameters
    ----------
-       max_update: maximum number of updates before the decay reaches 0.
-       base_lr: base learning rate
-       pwr: power of the decay term as a funtion of the current number of updates.
+        max_update: int
+            maximum number of updates before the decay reaches final learning rate.
+        base_lr: float
+            base learning rate to start from
+        pwr: int
+            power of the decay term as a function of the current number of updates.
+        final_lr: float
+            final learning rate after all steps
+        warmup_steps: int
+            number of warmup steps used before this scheduler starts decay
+        warmup_begin_lr: float
+            if using warmup, the learning rate from which it starts warming up
+        warmup_mode: string
+            warmup can be done in two modes.
+            'linear' mode gradually increases lr with each step in equal increments
+            'constant' mode keeps lr at warmup_begin_lr for warmup_steps
    """

-    def __init__(self, max_update, base_lr=0.01, pwr=2):
-        super(PolyScheduler, self).__init__(base_lr)
+    def __init__(self, max_update, base_lr=0.01, pwr=2, final_lr=0,
+                 warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):
+        super(PolyScheduler, self).__init__(base_lr, warmup_steps, warmup_begin_lr, warmup_mode)
        assert isinstance(max_update, int)
        if max_update < 1:
            raise ValueError("maximum number of updates must be strictly positive")
+        self.power = pwr
        self.base_lr_orig = self.base_lr
        self.max_update = max_update
-        self.power = pwr
-        self.base_lr = self.base_lr_orig
+        self.final_lr = final_lr
+        self.max_steps = self.max_update - self.warmup_steps

    def __call__(self, num_update):
+        if num_update < self.warmup_steps:
+            return self.get_warmup_lr(num_update)
        if num_update <= self.max_update:
+            self.base_lr = self.final_lr + (self.base_lr_orig - self.final_lr) * \
+                pow(1 - float(num_update - self.warmup_steps) / float(self.max_steps), self.power)
        return self.base_lr
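A worked example of the polynomial decay above, with no warmup so the formula reduces to final_lr + (base_lr - final_lr) * (1 - nup/max_update)^pwr (illustrative values):

    import mxnet as mx

    sched = mx.lr_scheduler.PolyScheduler(max_update=100, base_lr=1.0, pwr=2, final_lr=0.0)
    print(sched(0))     # 1.0
    print(sched(50))    # 0.25 = (1 - 50/100)^2
    print(sched(100))   # 0.0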

+class CosineScheduler(LRScheduler):
+    """ Reduce the learning rate according to a cosine function
+    Calculate the new learning rate by::
+       final_lr + (start_lr - final_lr) * (1+cos(pi * nup/max_nup))/2
+       if nup < max_nup, 0 otherwise.
+    Parameters
+    ----------
+        max_update: int
+            maximum number of updates before the decay reaches 0
+        base_lr: float
+            base learning rate
+        final_lr: float
+            final learning rate after all steps
+        warmup_steps: int
+            number of warmup steps used before this scheduler starts decay
+        warmup_begin_lr: float
+            if using warmup, the learning rate from which it starts warming up
+        warmup_mode: string
+            warmup can be done in two modes.
+            'linear' mode gradually increases lr with each step in equal increments
+            'constant' mode keeps lr at warmup_begin_lr for warmup_steps
+    """
+
+    def __init__(self, max_update, base_lr=0.01, final_lr=0,
+                 warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):
+        super(CosineScheduler, self).__init__(base_lr, warmup_steps, warmup_begin_lr, warmup_mode)
+        assert isinstance(max_update, int)
+        if max_update < 1:
+            raise ValueError("maximum number of updates must be strictly positive")
+        self.base_lr_orig = base_lr
+        self.max_update = max_update
+        self.final_lr = final_lr
+        self.max_steps = self.max_update - self.warmup_steps
+
+    def __call__(self, num_update):
+        if num_update < self.warmup_steps:
+            return self.get_warmup_lr(num_update)
+        if num_update <= self.max_update:
-            self.base_lr = self.base_lr_orig * pow(1.0 - float(num_update) / float(self.max_update),
-                                                   self.power)
+            self.base_lr = self.final_lr + (self.base_lr_orig - self.final_lr) * \
+                (1 + cos(pi * (num_update - self.warmup_steps) / self.max_steps)) / 2
+        return self.base_lr
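A short sketch of the cosine decay combined with a linear warmup; note that the cosine phase runs over the remaining max_update - warmup_steps updates (illustrative values):

    import mxnet as mx

    sched = mx.lr_scheduler.CosineScheduler(max_update=100, base_lr=1.0, final_lr=0.0,
                                            warmup_steps=20, warmup_begin_lr=0.0,
                                            warmup_mode='linear')
    print(sched(10))    # 0.5 -- halfway through the warmup ramp
    print(sched(20))    # 1.0 -- cosine phase starts at base_lr
    print(sched(60))    # 0.5 -- halfway through the remaining 80 updates
    print(sched(100))   # 0.0 -- decayed to final_lr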
52 changes: 52 additions & 0 deletions tests/python/unittest/test_optimizer.py
@@ -1040,6 +1040,58 @@ def test_adagrad():
g_stype='row_sparse')


+def test_factor_scheduler():
+    base_lr = 1
+    step = 100
+    factor = 0.1
+    sched = mx.lr_scheduler.FactorScheduler(step, factor, stop_factor_lr=1e-4, base_lr=base_lr,
+                                            warmup_steps=20, warmup_begin_lr=0.1, warmup_mode='constant')
+
+    assert (sched(0) == 0.1)
+    np.testing.assert_almost_equal(sched(10), 0.1)
+    assert (sched(21) == base_lr), sched(21)
+    np.testing.assert_almost_equal(sched(101), base_lr * factor)
+    np.testing.assert_almost_equal(sched(201), base_lr * factor * factor)
+    np.testing.assert_almost_equal(sched(1000), 1e-4)
+
+def test_multifactor_scheduler():
+    base_lr = 0.1
+    steps = [15, 25]
+    factor = 0.1
+    sched = mx.lr_scheduler.MultiFactorScheduler(steps, factor, base_lr=base_lr,
+                                                 warmup_steps=10, warmup_begin_lr=0.05, warmup_mode='linear')
+
+    assert sched(0) == 0.05
+    np.testing.assert_almost_equal(sched(5), 0.05 + (base_lr - 0.05)/2)
+    np.testing.assert_almost_equal(sched(15), base_lr)
+    np.testing.assert_almost_equal(sched(16), base_lr * factor)
+    np.testing.assert_almost_equal(sched(20), base_lr * factor)
+    np.testing.assert_almost_equal(sched(26), base_lr * factor * factor)
+    np.testing.assert_almost_equal(sched(100), base_lr * factor * factor)
+
+def test_poly_scheduler():
+    base_lr = 3
+    final_lr = 0
+    steps = 1000
+    poly_sched = mx.lr_scheduler.PolyScheduler(steps, base_lr=base_lr, pwr=2, final_lr=final_lr,
+                                               warmup_steps=100, warmup_begin_lr=0, warmup_mode='linear')
+
+    np.testing.assert_almost_equal(poly_sched(0), 0)
+    np.testing.assert_almost_equal(poly_sched(50), float(base_lr)/2)
+    np.testing.assert_almost_equal(poly_sched(100), base_lr)
+    assert (poly_sched(101) < poly_sched(100))
+    assert (poly_sched(500) < 1.6)
+    np.testing.assert_almost_equal(poly_sched(steps), final_lr)
+
+def test_cosine_scheduler():
+    # also tests case without warmup
+    base_lr = 3
+    final_lr = 0.1
+    steps = 1000
+    cosine_sched = mx.lr_scheduler.CosineScheduler(steps, base_lr=base_lr, final_lr=final_lr)
+    np.testing.assert_almost_equal(cosine_sched(0), base_lr)
+    np.testing.assert_almost_equal(cosine_sched(steps), final_lr)
+    assert (cosine_sched(500) > 1.5)

if __name__ == '__main__':
    import nose
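In training code these schedulers are normally handed to an optimizer rather than called directly; a minimal sketch of that wiring (the optimizer and module setup here are illustrative, not part of this commit):

    import mxnet as mx

    sched = mx.lr_scheduler.PolyScheduler(max_update=1000, base_lr=0.1, pwr=2,
                                          warmup_steps=100, warmup_begin_lr=0.0)
    optimizer = mx.optimizer.SGD(learning_rate=0.1, momentum=0.9, lr_scheduler=sched)
    # e.g. mod.init_optimizer(optimizer=optimizer) on an mx.mod.Module before training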
