This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-535] Fix bugs in LR Schedulers and add warmup #11234

Merged
merged 3 commits into from Aug 26, 2018
Changes from 1 commit
3 changes: 2 additions & 1 deletion example/image-classification/common/fit.py
@@ -49,7 +49,8 @@ def _get_lr_scheduler(args, kv):
steps = [epoch_size * (x - begin_epoch)
for x in step_epochs if x - begin_epoch > 0]
if steps:
return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor))
return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor,
base_lr=args.lr))
else:
return (lr, None)
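
For context, the (lr, lr_scheduler) pair returned by _get_lr_scheduler is consumed by the optimizer. A minimal sketch of that wiring, assuming a Module-style fit call as in fit.py (the surrounding names are illustrative, not part of this diff):

    # assumes `args`, `kv`, `model`, `train`, `val` are already set up as in fit.py
    lr, lr_scheduler = _get_lr_scheduler(args, kv)
    optimizer_params = {
        'learning_rate': lr,           # starting LR; the scheduler rescales from its base_lr
        'lr_scheduler': lr_scheduler,  # may be None when no decay steps remain
    }
    # model.fit(train, eval_data=val, optimizer='sgd', optimizer_params=optimizer_params, ...)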

105 changes: 89 additions & 16 deletions python/mxnet/lr_scheduler.py
@@ -17,6 +17,7 @@

"""Scheduling learning rate."""
import logging
from math import cos, pi

class LRScheduler(object):
"""Base class of a learning rate scheduler.
@@ -29,8 +30,31 @@ class LRScheduler(object):
base_lr : float, optional
The initial learning rate.
"""
def __init__(self, base_lr=0.01):
def __init__(self, base_lr=0.01, warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):
Member
Do you mind adding documentation for warmup_begin_lr?

Member Author
There was some for the inherited classes, but not for this base abstract class. Anyway, now added for all. Please check.

self.base_lr = base_lr
assert isinstance(warmup_steps, int)
self.warmup_steps = warmup_steps

self.warmup_final_lr = base_lr
self.warmup_begin_lr = warmup_begin_lr
if self.warmup_begin_lr > self.warmup_final_lr:
raise ValueError("Base lr has to be higher than warmup_begin_lr")
if self.warmup_steps < 0:
raise ValueError("Warmup steps has to be positive or 0")
if warmup_mode not in ['linear', 'constant']:
raise ValueError("Supports only linear and constant modes of warmup")
self.warmup_mode = warmup_mode

def get_warmup_lr(self, num_update):
assert num_update < self.warmup_steps
if self.warmup_mode == 'linear':
increase = (self.warmup_final_lr - self.warmup_begin_lr) \
* float(num_update)/float(self.warmup_steps)
return self.warmup_begin_lr + increase
elif self.warmup_mode == 'constant':
return self.warmup_begin_lr
else:
raise ValueError("Invalid warmup mode %s"%self.warmup_mode)

def __call__(self, num_update):
"""Return a new learning rate.
@@ -66,8 +90,9 @@ class FactorScheduler(LRScheduler):
stop_factor_lr : float, optional
Stop updating the learning rate if it is less than this value.
"""
def __init__(self, step, factor=1, stop_factor_lr=1e-8):
super(FactorScheduler, self).__init__()
def __init__(self, step, factor=1, stop_factor_lr=1e-8, base_lr=0.01,
warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):
super(FactorScheduler, self).__init__(base_lr, warmup_steps, warmup_begin_lr, warmup_mode)
if step < 1:
raise ValueError("Schedule step must be greater or equal than 1 round")
if factor > 1.0:
@@ -78,6 +103,9 @@ def __init__(self, step, factor=1, stop_factor_lr=1e-8):
self.count = 0

def __call__(self, num_update):
if num_update < self.warmup_steps:
return self.get_warmup_lr(num_update)

# NOTE: use while rather than if (for continuing training via load_epoch)
while num_update > self.count + self.step:
self.count += self.step
@@ -109,8 +137,10 @@ class MultiFactorScheduler(LRScheduler):
factor: float
The factor to change the learning rate.
"""
def __init__(self, step, factor=1):
super(MultiFactorScheduler, self).__init__()
def __init__(self, step, factor=1, base_lr=0.01, warmup_steps=0, warmup_begin_lr=0,
warmup_mode='linear'):
super(MultiFactorScheduler, self).__init__(base_lr, warmup_steps,
warmup_begin_lr, warmup_mode)
assert isinstance(step, list) and len(step) >= 1
for i, _step in enumerate(step):
if i != 0 and step[i] <= step[i-1]:
@@ -125,6 +155,9 @@ def __init__(self, step, factor=1):
self.count = 0

def __call__(self, num_update):
if num_update < self.warmup_steps:
return self.get_warmup_lr(num_update)

# NOTE: use while rather than if (for continuing training via load_epoch)
while self.cur_step_ind <= len(self.step)-1:
if num_update > self.step[self.cur_step_ind]:
@@ -138,33 +171,73 @@ def __call__(self, num_update):
return self.base_lr

class PolyScheduler(LRScheduler):
""" Reduce the learning rate according to a polynomial of given power.

Calculate the new learning rate by::

final_lr + (start_lr - final_lr) * (1-nup/max_nup)^pwr
if nup < max_nup, 0 otherwise.

Parameters
----------
max_update: maximum number of updates before the decay reaches final learning rate.
base_lr: base learning rate to start from
pwr: power of the decay term as a function of the current number of updates.
final_lr: final learning rate after all steps
warmup_steps: number of warmup steps used before this scheduler starts decay
"""

def __init__(self, max_update, base_lr=0.01, pwr=2, final_lr=0,
warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):
super(PolyScheduler, self).__init__(base_lr, warmup_steps, warmup_begin_lr, warmup_mode)
assert isinstance(max_update, int)
if max_update < 1:
raise ValueError("maximum number of updates must be strictly positive")
self.power = pwr
self.base_lr_orig = self.base_lr
self.max_update = max_update
self.final_lr = final_lr
self.max_steps = self.max_update - self.warmup_steps

def __call__(self, num_update):
if num_update < self.warmup_steps:
return self.get_warmup_lr(num_update)
if num_update <= self.max_update:
self.base_lr = self.final_lr + (self.base_lr_orig - self.final_lr) * \
pow(1 - float(num_update - self.warmup_steps) / float(self.max_steps), self.power)
return self.base_lr
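
A quick numeric check of the polynomial decay implemented above, using made-up settings (a sketch, not part of the diff):

    # PolyScheduler-style decay: base_lr=1.0, final_lr=0.0, pwr=2, max_update=100, no warmup
    base_lr, final_lr, power, max_update, warmup_steps = 1.0, 0.0, 2, 100, 0
    max_steps = max_update - warmup_steps
    for num_update in (0, 50, 100):
        lr = final_lr + (base_lr - final_lr) * \
            pow(1 - float(num_update - warmup_steps) / float(max_steps), power)
        print(num_update, lr)  # 0 -> 1.0, 50 -> 0.25, 100 -> 0.0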

class CosineScheduler(LRScheduler):
""" Reduce the learning rate by given a list of steps.

Calculate the new learning rate by::

base_lr * (1-nup/max_nup)^pwr
final_lr + (start_lr - final_lr) * (1+cos(pi * nup/max_nup))/2
if nup < max_nup, 0 otherwise.

Parameters
----------
max_update: maximum number of updates before the decay reaches 0.
max_update: maximum number of updates before the decay reaches 0
base_lr: base learning rate
pwr: power of the decay term as a function of the current number of updates.

final_lr: final learning rate after all steps
warmup_steps: number of warmup steps used before this scheduler starts decay
"""

def __init__(self, max_update, base_lr=0.01, pwr=2):
Contributor
don't remove base_lr, it will break API.
Pass it to super init instead
super(PolyScheduler, self).__init__(base_lr)
def __init__(self, max_update, base_lr=0.01, final_lr=0,
Contributor
why did you remove pwr? This is API breakage

Member Author (@rahul003, Aug 16, 2018)
I've not removed it. Git is getting confused :/ It thinks I've changed PolyScheduler to CosineScheduler when in fact I've modified PolyScheduler and added a new CosineScheduler.

Member Author
Please refer #11234 (comment)
warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):
super(CosineScheduler, self).__init__(base_lr, warmup_steps, warmup_begin_lr, warmup_mode)
assert isinstance(max_update, int)
if max_update < 1:
raise ValueError("maximum number of updates must be strictly positive")
self.base_lr_orig = self.base_lr
self.base_lr_orig = base_lr
self.max_update = max_update
self.power = pwr
self.base_lr = self.base_lr_orig
self.final_lr = final_lr
self.max_steps = self.max_update - self.warmup_steps

def __call__(self, num_update):
if num_update < self.warmup_steps:
return self.get_warmup_lr(num_update)
if num_update <= self.max_update:
self.base_lr = self.base_lr_orig * pow(1.0 - float(num_update) / float(self.max_update),
self.power)
self.base_lr = self.final_lr + (self.base_lr_orig - self.final_lr) * \
(1 + cos(pi * (num_update - self.warmup_steps) / self.max_steps)) / 2
return self.base_lr
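
Putting the new pieces together, a hedged end-to-end sketch of using CosineScheduler with warmup after this change (the Gluon Trainer line is illustrative, not part of this PR):

    import mxnet as mx

    # cosine decay from 0.1 down to 0.001 over 10000 updates, with 500 linear warmup steps
    sched = mx.lr_scheduler.CosineScheduler(max_update=10000, base_lr=0.1, final_lr=0.001,
                                            warmup_steps=500, warmup_begin_lr=0.0,
                                            warmup_mode='linear')
    # the scheduler is consumed by the optimizer, e.g. (illustrative):
    # trainer = mx.gluon.Trainer(net.collect_params(), 'sgd',
    #                            {'learning_rate': 0.1, 'lr_scheduler': sched})
    print([round(sched(i), 4) for i in (0, 250, 500, 5000, 10000)])
    # roughly: 0.0 (warmup start), 0.05 (mid-warmup), 0.1 (decay start), ~0.055 (mid-decay), 0.001 (end)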
49 changes: 49 additions & 0 deletions tests/python/unittest/test_optimizer.py
@@ -1033,6 +1033,55 @@ def test_adagrad():
w_stype='row_sparse', g_stype='row_sparse')


def test_factor_scheduler():
base_lr = 1
step = 100
factor = 0.1
sched = mx.lr_scheduler.FactorScheduler(step, factor, stop_factor_lr=1e-4, base_lr=base_lr,
warmup_steps=20, warmup_begin_lr=0.1, warmup_mode='constant')
assert (sched(0) == 0.1)
np.testing.assert_almost_equal(sched(10), 0.1)
assert (sched(21) == base_lr), sched(21)
np.testing.assert_almost_equal(sched(101), base_lr * factor)
np.testing.assert_almost_equal(sched(201), base_lr * factor * factor)
np.testing.assert_almost_equal(sched(1000), 1e-4)

def test_multifactor_scheduler():
base_lr = 0.1
steps = [15, 25]
factor = 0.1
sched = mx.lr_scheduler.MultiFactorScheduler(steps, factor, base_lr=base_lr,
warmup_steps=10, warmup_begin_lr=0.05, warmup_mode='linear')
assert sched(0) == 0.05
np.testing.assert_almost_equal(sched(5), 0.05 + (base_lr - 0.05)/2)
np.testing.assert_almost_equal(sched(15), base_lr)
np.testing.assert_almost_equal(sched(16), base_lr * factor)
np.testing.assert_almost_equal(sched(20), base_lr * factor)
np.testing.assert_almost_equal(sched(26), base_lr * factor * factor)
np.testing.assert_almost_equal(sched(100), base_lr * factor * factor)

def test_poly_scheduler():
base_lr = 3
final_lr = 0
steps = 1000
poly_sched = mx.lr_scheduler.PolyScheduler(steps, base_lr=base_lr, pwr=2, final_lr=final_lr,
warmup_steps=100, warmup_begin_lr=0, warmup_mode='linear')
np.testing.assert_almost_equal(poly_sched(0), 0)
np.testing.assert_almost_equal(poly_sched(50), float(base_lr)/2)
np.testing.assert_almost_equal(poly_sched(100), base_lr)
assert (poly_sched(101) < poly_sched(100))
assert (poly_sched(500) < 1.6)
np.testing.assert_almost_equal(poly_sched(steps), final_lr)

def test_cosine_scheduler():
# also tests case without warmup
base_lr = 3
final_lr = 0.1
steps = 1000
cosine_sched = mx.lr_scheduler.CosineScheduler(steps, base_lr=base_lr, final_lr=final_lr)
np.testing.assert_almost_equal(cosine_sched(0), base_lr)
np.testing.assert_almost_equal(cosine_sched(steps), final_lr)
assert (cosine_sched(500) > 1.5)

if __name__ == '__main__':
import nose