Skip to content

Commit 4f61e20

Browse files
Add doctest for MultiStepScheduler (#2399)
1 parent 5a17cd0 commit 4f61e20

File tree

1 file changed

+35
-8
lines changed

1 file changed

+35
-8
lines changed

Diff for: ignite/handlers/state_param_scheduler.py

+35-8
Original file line numberDiff line numberDiff line change
@@ -460,21 +460,48 @@ class MultiStepStateScheduler(StateParamScheduler):
460460
461461
Examples:
462462
463-
.. code-block:: python
463+
.. testsetup::
464464
465-
...
466-
engine = Engine(train_step)
465+
default_trainer = get_default_trainer()
466+
467+
.. testcode::
467468
468469
param_scheduler = MultiStepStateScheduler(
469-
param_name="param", initial_value=10, gamma=0.99, milestones=[3, 6],
470+
param_name="param", initial_value=1, gamma=0.9, milestones=[3, 6, 9, 12]
470471
)
471472
472-
param_scheduler.attach(engine, Events.EPOCH_COMPLETED)
473+
# parameter is param, initial_value sets param to 1, gamma is set as 0.9
474+
# Epoch 1 to 2, param does not change as milestone is 3
475+
# Epoch 3, param changes from 1 to 1*0.9, param = 0.9
476+
# Epoch 3 to 5, param does not change as milestone is 6
477+
# Epoch 6, param changes from 0.9 to 0.9*0.9, param = 0.81
478+
# Epoch 6 to 8, param does not change as milestone is 9
479+
# Epoch 9, param changes from 0.81 to 0.81*0.9, param = 0.729
480+
# Epoch 9 to 11, param does not change as milestone is 12
481+
# Epoch 12, param changes from 0.729 to 0.729*0.9, param = 0.6561
482+
483+
param_scheduler.attach(default_trainer, Events.EPOCH_COMPLETED)
484+
485+
@default_trainer.on(Events.EPOCH_COMPLETED)
486+
def print_param():
487+
print(default_trainer.state.param)
488+
489+
default_trainer.run([0], max_epochs=12)
473490
474-
# basic handler to print scheduled state parameter
475-
engine.add_event_handler(Events.EPOCH_COMPLETED, lambda _ : print(engine.state.param))
491+
.. testoutput::
476492
477-
engine.run([0] * 8, max_epochs=10)
493+
1.0
494+
1.0
495+
0.9
496+
0.9
497+
0.9
498+
0.81
499+
0.81
500+
0.81
501+
0.7290...
502+
0.7290...
503+
0.7290...
504+
0.6561
478505
479506
.. versionadded:: 0.5.0
480507

0 commit comments

Comments
 (0)