@@ -460,21 +460,48 @@ class MultiStepStateScheduler(StateParamScheduler):
460
460
461
461
Examples:
462
462
463
- .. code-block:: python
463
+ .. testsetup::
464
464
465
- ...
466
- engine = Engine(train_step)
465
+ default_trainer = get_default_trainer()
466
+
467
+ .. testcode::
467
468
468
469
param_scheduler = MultiStepStateScheduler(
469
- param_name="param", initial_value=10 , gamma=0.99 , milestones=[3, 6],
470
+ param_name="param", initial_value=1 , gamma=0.9 , milestones=[3, 6, 9, 12]
470
471
)
471
472
472
- param_scheduler.attach(engine, Events.EPOCH_COMPLETED)
473
+ # parameter is param, initial_value sets param to 1, gamma is set as 0.9
474
+ # Epoch 1 to 2, param does not change as milestone is 3
475
+ # Epoch 3, param changes from 1 to 1*0.9, param = 0.9
476
+ # Epoch 3 to 5, param does not change as milestone is 6
477
+ # Epoch 6, param changes from 0.9 to 0.9*0.9, param = 0.81
478
+ # Epoch 6 to 8, param does not change as milestone is 9
479
+ # Epoch 9, param changes from 0.81 to 0.81*0.9, param = 0.729
480
+ # Epoch 9 to 11, param does not change as milestone is 12
481
+ # Epoch 12, param changes from 0.729 to 0.729*0.9, param = 0.6561
482
+
483
+ param_scheduler.attach(default_trainer, Events.EPOCH_COMPLETED)
484
+
485
+ @default_trainer.on(Events.EPOCH_COMPLETED)
486
+ def print_param():
487
+ print(default_trainer.state.param)
488
+
489
+ default_trainer.run([0], max_epochs=12)
473
490
474
- # basic handler to print scheduled state parameter
475
- engine.add_event_handler(Events.EPOCH_COMPLETED, lambda _ : print(engine.state.param))
491
+ .. testoutput::
476
492
477
- engine.run([0] * 8, max_epochs=10)
493
+ 1.0
494
+ 1.0
495
+ 0.9
496
+ 0.9
497
+ 0.9
498
+ 0.81
499
+ 0.81
500
+ 0.81
501
+ 0.7290...
502
+ 0.7290...
503
+ 0.7290...
504
+ 0.6561
478
505
479
506
.. versionadded:: 0.5.0
480
507
0 commit comments