lr_scheduler.py
"""Popular Learning Rate Schedulers"""
# pylint: disable=missing-docstring
from __future__ import division
import warnings
from math import pi, cos
from mxnet import lr_scheduler


class LRSequential(lr_scheduler.LRScheduler):
    r"""Compose Learning Rate Schedulers

    Parameters
    ----------
    schedulers : list
        list of LRScheduler objects
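
    Examples
    --------
    A minimal sketch, assuming the ``LRScheduler`` class defined later in
    this module: a 5-iteration linear warmup chained into a 95-iteration
    cosine decay.

    >>> warmup = LRScheduler('linear', base_lr=0, target_lr=0.1, niters=5)
    >>> cosine = LRScheduler('cosine', base_lr=0.1, target_lr=0, niters=95)
    >>> schedule = LRSequential([warmup, cosine])
    >>> schedule(0)    # warmup begins at base_lr
    0.0
    >>> schedule(2)    # halfway through the warmup
    0.05
    >>> schedule(99)   # cosine anneals to target_lr on the last iteration
    0.0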
"""

    def __init__(self, schedulers):
        super(LRSequential, self).__init__()
        assert len(schedulers) > 0

        self.update_sep = []
        self.count = 0
        self.learning_rate = 0
        self.schedulers = []
        for lr in schedulers:
            self.add(lr)

    def add(self, scheduler):
        assert isinstance(scheduler, LRScheduler)

        scheduler.offset = self.count
        self.count += scheduler.niters
        self.update_sep.append(self.count)
        self.schedulers.append(scheduler)

    def __call__(self, num_update):
        self.update(num_update)
        return self.learning_rate

    def update(self, num_update):
        num_update = min(num_update, self.count - 1)
        # pick the first scheduler whose cumulative end point lies beyond
        # num_update, then delegate the update to it
        ind = len(self.schedulers) - 1
        for i, sep in enumerate(self.update_sep):
            if sep > num_update:
                ind = i
                break
        lr = self.schedulers[ind]
        lr.update(num_update)
        self.learning_rate = lr.learning_rate


class LRScheduler(lr_scheduler.LRScheduler):
    r"""Learning Rate Scheduler

    Parameters
    ----------
    mode : str
        Modes for learning rate scheduler.
        Currently it supports 'constant', 'step', 'linear', 'poly' and 'cosine'.
    base_lr : float
        Base learning rate, i.e. the starting learning rate.
    target_lr : float
        Target learning rate, i.e. the ending learning rate.
        With constant mode target_lr is ignored.
    niters : int
        Number of iterations to be scheduled.
    nepochs : int
        Number of epochs to be scheduled.
    iters_per_epoch : int
        Number of iterations in each epoch.
    offset : int
        Number of iterations before this scheduler.
    power : float
        Power parameter of poly scheduler.
    step_iter : list
        A list of iterations to decay the learning rate.
    step_epoch : list
        A list of epochs to decay the learning rate.
    step_factor : float
        Learning rate decay factor.
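
    Examples
    --------
    A minimal sketch of the closed-form schedules; outputs follow the
    formulas in ``update`` below.

    >>> sched = LRScheduler('cosine', base_lr=0.1, target_lr=0, niters=100)
    >>> sched(0)    # cosine starts at base_lr
    0.1
    >>> sched(99)   # and anneals to target_lr on the last iteration
    0.0
    >>> step = LRScheduler('step', base_lr=0.1, niters=90,
    ...                    step_iter=[30, 60], step_factor=0.1)
    >>> round(step(45), 4)  # one decay applied after iteration 30
    0.01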
"""

    def __init__(self, mode, base_lr=0.1, target_lr=0,
                 niters=0, nepochs=0, iters_per_epoch=0, offset=0,
                 power=2, step_iter=None, step_epoch=None, step_factor=0.1,
                 baselr=None, targetlr=None):
        super(LRScheduler, self).__init__()
        assert mode in ['constant', 'step', 'linear', 'poly', 'cosine']
        self.mode = mode
        if mode == 'step':
            assert step_iter is not None or step_epoch is not None
        if baselr is not None:
            warnings.warn("baselr is deprecated. Please use base_lr.")
            if base_lr == 0.1:
                base_lr = baselr
        self.base_lr = base_lr
        if targetlr is not None:
            warnings.warn("targetlr is deprecated. Please use target_lr.")
            if target_lr == 0:
                target_lr = targetlr
        self.target_lr = target_lr
        if self.mode == 'constant':
            self.target_lr = self.base_lr

        self.niters = niters
        self.step = step_iter
        # epoch-based arguments, when given, override the iteration-based ones
        epoch_iters = nepochs * iters_per_epoch
        if epoch_iters > 0:
            self.niters = epoch_iters
            if step_epoch is not None:
                self.step = [s * iters_per_epoch for s in step_epoch]

        self.offset = offset
        self.power = power
        self.step_factor = step_factor

    def __call__(self, num_update):
        self.update(num_update)
        return self.learning_rate

    def update(self, num_update):
        N = self.niters - 1
        # progress T is measured relative to this scheduler's offset and
        # clamped into [0, N]
        T = num_update - self.offset
        T = min(max(0, T), N)

        if self.mode == 'constant':
            factor = 0
        elif self.mode == 'linear':
            factor = 1 - T / N
        elif self.mode == 'poly':
            factor = pow(1 - T / N, self.power)
        elif self.mode == 'cosine':
            factor = (1 + cos(pi * T / N)) / 2
        elif self.mode == 'step':
            if self.step is not None:
                count = sum([1 for s in self.step if s <= T])
                factor = pow(self.step_factor, count)
            else:
                factor = 1
        else:
            raise NotImplementedError

        if self.mode == 'step':
            self.learning_rate = self.base_lr * factor
        else:
            self.learning_rate = self.target_lr + (self.base_lr - self.target_lr) * factor
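

if __name__ == '__main__':
    # A quick illustrative sketch (the hyperparameters below are arbitrary):
    # compose a linear warmup with a step decay and print the learning rate
    # at a few checkpoints of the combined schedule.
    warmup = LRScheduler('linear', base_lr=0, target_lr=0.1, niters=5)
    decay = LRScheduler('step', base_lr=0.1, niters=95,
                        step_iter=[40, 70], step_factor=0.1)
    schedule = LRSequential([warmup, decay])
    for it in (0, 4, 5, 50, 99):
        print('iteration %3d: lr = %g' % (it, schedule(it)))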