[skip ci] Add doctest for LinearCyclicalScheduler (pytorch#2327)
sdesrozis authored and fco-dv committed Nov 23, 2021
1 parent 32515c3 commit ed2f0ee
Showing 2 changed files with 91 additions and 11 deletions.
34 changes: 29 additions & 5 deletions docs/source/conf.py
@@ -332,6 +332,8 @@ def run(self):
 
 # doctest config
 doctest_global_setup = """
+from collections import OrderedDict
+
 import torch
 from torch import nn, optim
@@ -340,18 +342,40 @@ def run(self):
 from ignite.metrics import *
 from ignite.utils import *
 from ignite.contrib.metrics.regression import *
 from ignite.contrib.metrics import *
-manual_seed(666)
 
 # create default evaluator for doctests
-def process_function(engine, batch):
+def eval_step(engine, batch):
     y_pred, y = batch
     return y_pred, y
-default_evaluator = Engine(process_function)
+default_evaluator = Engine(eval_step)
+
+# create default optimizer for doctests
+param_tensor = torch.zeros([1], requires_grad=True)
+default_optimizer = torch.optim.SGD([param_tensor], lr=0)
+
+# create default trainer for doctests
+# as handlers could be attached to the trainer,
+# each test must define its own trainer using `.. testsetup::`
+def get_default_trainer():
+    def train_step(engine, batch):
+        return 0.0
+    return Engine(train_step)
+
+# create default model for doctests
+default_model = nn.Sequential(OrderedDict([
+    ('base', nn.Linear(4, 2)),
+    ('fc', nn.Linear(2, 1))
+]))
+
+manual_seed(666)
 """


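For context on how these fixtures are consumed: sphinx.ext.doctest executes doctest_global_setup before each doctest group, so docstring examples can refer to default_evaluator, default_optimizer, default_model, and get_default_trainer() without defining them. The following is a minimal standalone sketch, not part of the commit, that mirrors part of the setup above and checks that the default trainer runs; it assumes torch and pytorch-ignite are installed.

from collections import OrderedDict

import torch
from torch import nn

from ignite.engine import Engine
from ignite.utils import manual_seed

manual_seed(666)

# Shared optimizer fixture: a single zero tensor with lr=0, presumably so that
# scheduler examples fully control the learning-rate values that get printed.
param_tensor = torch.zeros([1], requires_grad=True)
default_optimizer = torch.optim.SGD([param_tensor], lr=0)

# Shared trainer factory: a fresh Engine per doctest, so handlers attached in
# one example do not leak into the next one.
def get_default_trainer():
    def train_step(engine, batch):
        return 0.0
    return Engine(train_step)

# Shared two-layer model fixture with named parameter groups ("base", "fc").
default_model = nn.Sequential(OrderedDict([
    ("base", nn.Linear(4, 2)),
    ("fc", nn.Linear(2, 1)),
]))

if __name__ == "__main__":
    trainer = get_default_trainer()
    state = trainer.run([0] * 9, max_epochs=1)  # 9 dummy batches, 1 epoch
    print(state.iteration)                      # -> 9
    print(default_optimizer.param_groups[0]["lr"])  # -> 0

Creating a fresh trainer per test is what the comment in conf.py asks for: handlers attached in one example would otherwise remain attached in the next.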
68 changes: 62 additions & 6 deletions ignite/handlers/param_scheduler.py
@@ -360,15 +360,71 @@ class LinearCyclicalScheduler(CyclicalScheduler):
             usually be the number of batches in an epoch.
 
     Examples:
 
-    .. code-block:: python
+    .. testsetup:: *
+
+        default_trainer = get_default_trainer()
+
+    .. testcode:: 1
 
         from ignite.handlers.param_scheduler import LinearCyclicalScheduler
 
-        scheduler = LinearCyclicalScheduler(optimizer, 'lr', 1e-3, 1e-1, len(train_loader))
-        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
-        #
-        # Linearly increases the learning rate from 1e-3 to 1e-1 and back to 1e-3
-        # over the course of 1 epoch
+        # Linearly increases the learning rate from 0.0 to 1.0 and back to 0.0
+        # over a cycle of 4 iterations
+        scheduler = LinearCyclicalScheduler(default_optimizer, "lr", 0.0, 1.0, 4)
+
+        default_trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
+
+        @default_trainer.on(Events.ITERATION_COMPLETED)
+        def print_lr():
+            print(default_optimizer.param_groups[0]["lr"])
+
+        default_trainer.run([0] * 9, max_epochs=1)
+
+    .. testoutput:: 1
+
+        0.0
+        0.5
+        1.0
+        0.5
+        ...
+
+    .. testcode:: 2
+
+        from ignite.handlers.param_scheduler import LinearCyclicalScheduler
+
+        optimizer = torch.optim.SGD(
+            [
+                {"params": default_model.base.parameters(), "lr": 0.001},
+                {"params": default_model.fc.parameters(), "lr": 0.01},
+            ]
+        )
+
+        # Linearly increases the learning rate from 0.0 to 1.0 and back to 0.0
+        # over a cycle of 4 iterations
+        scheduler1 = LinearCyclicalScheduler(optimizer, "lr", 0.0, 1.0, 4, param_group_index=0)
+
+        # Linearly increases the learning rate from 0.0 to 0.1 and back to 0.0
+        # over a cycle of 4 iterations
+        scheduler2 = LinearCyclicalScheduler(optimizer, "lr", 0.0, 0.1, 4, param_group_index=1)
+
+        default_trainer.add_event_handler(Events.ITERATION_STARTED, scheduler1)
+        default_trainer.add_event_handler(Events.ITERATION_STARTED, scheduler2)
+
+        @default_trainer.on(Events.ITERATION_COMPLETED)
+        def print_lr():
+            print(optimizer.param_groups[0]["lr"],
+                  optimizer.param_groups[1]["lr"])
+
+        default_trainer.run([0] * 9, max_epochs=1)
+
+    .. testoutput:: 2
+
+        0.0 0.0
+        0.5 0.05
+        1.0 0.1
+        0.5 0.05
+        ...
+
     .. versionadded:: 0.4.5
     """