
Learning rate stepping option #941

Merged: 34 commits into master, Mar 5, 2020
Conversation

@SkafteNicki (Member) commented Feb 25, 2020

Before submitting

  • Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
  • Did you read the contributor guideline?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

What does this PR do?

Fixes #806
Fixes part of #827
Fixes #945

This PR allows the user to choose the frequency at which learning rate schedulers are called, by supplying it as:

def configure_optimizers(self):
    optimizer = ...
    scheduler = ...
    return [optimizer], [[scheduler, freq]]

This will call scheduler.step() every freq training steps. The user can still just write

    return [optimizer], [scheduler]

which (as now) will call scheduler.step() after every epoch, so the changes are backward compatible. This may not be the most intuitive interface, and I am happy to discuss other ways of doing it.

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@williamFalcon (Contributor) commented:

@SkafteNicki awesome haha. Can you replace it with a dict instead of an array? (#945)

Fixes #945

@williamFalcon (Contributor) commented:

@SkafteNicki mind rebasing as well? :)

@williamFalcon added this to the 0.6.1 milestone Feb 25, 2020
@SkafteNicki (Member, Author) commented:

@williamFalcon will do the rebase and change the structure to a dict within the next day or so. You are right that for future features it is probably best to have a dict-like structure. I propose three fields:
{'scheduler': LRScheduler, 'interval': 'batch'|'epoch', 'frequency': int}
where 'interval' defaults to 'epoch' and 'frequency' defaults to 1. We should probably still allow the user to pass a plain list of LRSchedulers for backward compatibility. A sketch of what this could look like follows below.
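A minimal sketch of how the proposed dict format could look inside configure_optimizers (an illustration of the proposal above, not the PR's final code; the optimizer and scheduler choices are arbitrary):

import torch

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    # Proposed dict format: step the scheduler every 100 training batches.
    return [optimizer], [{'scheduler': scheduler,
                          'interval': 'batch',
                          'frequency': 100}]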

@williamFalcon (Contributor) commented:

cool. do we need frequency?
@srush @ethanwharris

@ethanwharris (Member) commented:

This looks nice :) Personally I would opt for a batch / epoch option as suggested by @williamFalcon in #945

There may be some value to a frequency argument for iterable stuff (where there's no real notion of an epoch) - but perhaps the options could be 'epoch', 'batch', or int?

It should also work smoothly with the gradient accumulation stuff, i.e. 'batch' should really mean a call to .step() (see the sketch below).
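To make the gradient accumulation point concrete, a small self-contained sketch (toy model and numbers are illustrative only): with accumulation, optimizer.step() runs once every few batches, and a 'batch'-interval scheduler should step together with the optimizer rather than on every batch.

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
accumulate_grad_batches = 4

for batch_idx in range(16):
    x = torch.randn(8, 10)
    loss = model(x).pow(2).mean()
    (loss / accumulate_grad_batches).backward()
    if (batch_idx + 1) % accumulate_grad_batches == 0:
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()  # one scheduler step per optimizer step, not per batch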

@srush (Contributor) commented Feb 26, 2020 via email

@pep8speaks commented Feb 26, 2020

Hello @SkafteNicki! Thanks for updating this PR.

Line 786:101: E501 line too long (119 > 100 characters)

Line 34:101: E501 line too long (103 > 100 characters)
Line 35:101: E501 line too long (101 > 100 characters)
Line 37:101: E501 line too long (109 > 100 characters)
Line 70:101: E501 line too long (109 > 100 characters)

Line 725:101: E501 line too long (101 > 100 characters)

Line 85:101: E501 line too long (116 > 100 characters)

Line 48:101: E501 line too long (110 > 100 characters)
Line 88:101: E501 line too long (114 > 100 characters)

Comment last updated at 2020-03-05 08:42:37 UTC

@Borda added the 'feature' (Is an improvement or enhancement) label Feb 26, 2020
@Borda (Member) left a comment:


it is becoming quite long... what about moving the LR logic to a separate mixin, and the tests to test_trainer_lr?

Review comments (since resolved) on:
  • pytorch_lightning/trainer/trainer.py (5)
  • pytorch_lightning/trainer/training_io.py
  • pytorch_lightning/trainer/training_loop.py
@williamFalcon (Contributor) commented Feb 27, 2020

@SkafteNicki awesome! Let's rebase so we can get this merged :)

@williamFalcon (Contributor) commented:

@SkafteNicki fix tests?

@SkafteNicki (Member, Author) commented:

@williamFalcon the failing test seems to be unrelated to the PR; tests were passing before the latest merge.

@ethanwharris (Member) commented:

Ok, so merging master has caused some issues here: the callback test from test_trainer.py now lives in tests/trainer/test_callbacks.py.

@SkafteNicki - the simplest thing to do would be to look at the diff and remove the bits that aren't yours - or revert the merge and do a rebase.

@williamFalcon (Contributor) commented:

@SkafteNicki awesome! Very close.

There are some failing GPU tests:

_____________________________________________________________________________________________ test_amp_gpu_ddp _____________________________________________________________________________________________

tmpdir = local('/tmp/pytest-of-waf251/pytest-45/test_amp_gpu_ddp0')

    def test_amp_gpu_ddp(tmpdir):
        """Make sure DDP + AMP work."""
        if not tutils.can_run_gpu_test():
            return

        tutils.reset_seed()
        tutils.set_random_master_port()

        hparams = tutils.get_hparams()
        model = LightningTestModel(hparams)

        trainer_options = dict(
            default_save_path=tmpdir,
            show_progress_bar=True,
            max_epochs=1,
            gpus=2,
            distributed_backend='ddp',
            precision=16
        )

>       tutils.run_model_test(trainer_options, model)

tests/test_amp.py:81:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/models/utils.py:88: in run_model_test
    trainer.hpc_save(save_dir, logger)
pytorch_lightning/trainer/training_io.py:431: in hpc_save
    checkpoint = self.dump_checkpoint()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <pytorch_lightning.trainer.trainer.Trainer object at 0x7f677f977bd0>

    def dump_checkpoint(self):
        checkpoint = {
            'epoch': self.current_epoch + 1,
            'global_step': self.global_step + 1,
        }

        if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
            checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best

        if self.early_stop_callback is not None and self.checkpoint_callback is not False:
            checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait
            checkpoint['early_stop_callback_patience'] = self.early_stop_callback.patience

        # save optimizers
        optimizer_states = []
        for i, optimizer in enumerate(self.optimizers):
            optimizer_states.append(optimizer.state_dict())

        checkpoint['optimizer_states'] = optimizer_states

        # save lr schedulers
        lr_schedulers = []
        for scheduler in self.lr_schedulers:
>           lr_schedulers.append(scheduler['scheduler'].state_dict())
E           TypeError: 'CosineAnnealingLR' object is not subscriptable
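
The failure above occurs because dump_checkpoint now indexes each entry of self.lr_schedulers as a dict, while a raw scheduler object can still reach it. A minimal sketch of one way to resolve this, assuming bare schedulers are normalized into the dict format when the trainer sets up optimizers (the helper name is illustrative, not the PR's actual code):

def _normalize_schedulers(schedulers):
    # Wrap bare scheduler objects in the dict format so downstream code
    # (e.g. checkpointing) can always rely on entry['scheduler'].
    normalized = []
    for entry in schedulers:
        if isinstance(entry, dict):
            normalized.append(entry)
        else:
            normalized.append({'scheduler': entry, 'interval': 'epoch', 'frequency': 1})
    return normalized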


@williamFalcon (Contributor) commented:

@SkafteNicki any chance you can finish this in the next few hours so we can merge and release?

@SkafteNicki (Member, Author) commented:

@williamFalcon, sorry for the delay, I have been really busy the last couple of days.
The latest change should fix the GPU test.

The failing checks all seem to be caused by
urllib.error.HTTPError: HTTP Error 403: Forbidden
(nothing to do with this PR), which has something to do with the MNIST download, I guess.

@Borda (Member) commented Mar 4, 2020

> The failing checks all seem to be caused by
> urllib.error.HTTPError: HTTP Error 403: Forbidden
> (nothing to do with this PR), which has something to do with the MNIST download, I guess.

Yeah, it has been happening since yesterday...

@SkafteNicki (Member, Author) commented:

@Borda anything I can do?

@Borda (Member) left a comment:


can we make the monitored "val_loss" metric a parameter?

Review comments (since resolved) on:
  • pytorch_lightning/trainer/training_loop.py
  • tests/trainer/test_optimizers.py (5)
@Borda (Member) commented Mar 4, 2020

> @Borda anything I can do?

I guess it is fine (maybe check Will's last comment)... Great job again!
We will try to do something about the dataset downloading... :/

@SkafteNicki (Member, Author) commented:

@Borda I have changed the code so that users can also pass a monitor key, which ReduceLROnPlateau schedulers are then conditioned on. By default it is set to val_loss. A sketch of the resulting interface follows below.
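A minimal sketch of what the monitor key could look like in practice (a hedged illustration based on the description above; the optimizer and scheduler choices are arbitrary):

import torch

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
    # 'monitor' names the logged metric that the plateau scheduler is
    # conditioned on; per the comment above, it defaults to 'val_loss'.
    return [optimizer], [{'scheduler': scheduler,
                          'interval': 'epoch',
                          'frequency': 1,
                          'monitor': 'val_loss'}]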

@williamFalcon (Contributor) commented:

______________________________________________________________________________________ test_optimizer_return_options _______________________________________________________________________________________

    def test_optimizer_return_options():
        tutils.reset_seed()

        trainer = Trainer()
        model, hparams = tutils.get_model()

        # single optimizer
        opt_a = torch.optim.Adam(model.parameters(), lr=0.002)
        opt_b = torch.optim.SGD(model.parameters(), lr=0.002)
        optim, lr_sched = trainer.init_optimizers(opt_a)
        assert len(optim) == 1 and len(lr_sched) == 0

        # opt tuple
        opts = (opt_a, opt_b)
        optim, lr_sched = trainer.init_optimizers(opts)
        assert len(optim) == 2 and optim[0] == opts[0] and optim[1] == opts[1]
        assert len(lr_sched) == 0

        # opt list
        opts = [opt_a, opt_b]
        optim, lr_sched = trainer.init_optimizers(opts)
        assert len(optim) == 2 and optim[0] == opts[0] and optim[1] == opts[1]
        assert len(lr_sched) == 0

        # opt tuple of lists
        scheduler = torch.optim.lr_scheduler.StepLR(opt_a, 10)
        opts = ([opt_a], [scheduler])
        optim, lr_sched = trainer.init_optimizers(opts)
        assert len(optim) == 1 and len(lr_sched) == 1
>       assert optim[0] == opts[0][0] and \
            lr_sched[0] == dict(scheduler=scheduler, interval='epoch',
                                frequency=1, reduce_on_plateau=False)
E       AssertionError: assert (Adam (\nParame...ght_decay: 0\n) == Adam (\nParame...ght_decay: 0\n)
E         -Adam (\n
E         -Parameter Group 0\n
E         -    amsgrad: False\n
E         -    betas: (0.9, 0.999)\n
E         -    eps: 1e-08\n
E         -    initial_lr: 0.002\n
E         -    lr: 0.002\n...
E
E         ...Full output truncated (12 lines hidden), use '-vv' to show and {'frequency':...': False, ...} == {'frequency':...7f1c2c0acc90>}
E         Omitting 4 identical items, use -vv to show
E         Left contains 1 more item:
E         {'monitor': 'val_loss'}
E         Full diff:
E           {
E            'frequency': 1,
E            'interval': 'epoch',...
E
E         ...Full output truncated (5 lines hidden), use '-vv' to show)

tests/test_gpu_models.py:123: AssertionError

@SkafteNicki (Member, Author) commented:

@williamFalcon fixed the failing test, so all checks are passing now

@williamFalcon merged commit 969e929 into Lightning-AI:master Mar 5, 2020
@SkafteNicki mentioned this pull request Apr 2, 2020
tullie pushed a commit to tullie/pytorch-lightning that referenced this pull request Apr 3, 2020
* remove deprecated args to learning rate step function
* step based scheduler
* mixing models for testing
* fix styling
* tests
* update documentation
* smaller fix
* update to dict structure
* updated test
* update documentation
* update CHANGELOG.md
* fix styling
* fix problems with trainer io
* fix tests
* simplification of code
* fix styling
* change from batch to step
* update to tests
* fix styling
* fixed some logic
* Update pytorch_lightning/core/lightning.py
* duplicated test
* fix test on amp
* small update to tests
* added monitor key for ReduceLROnPlateau
* Update trainer.py
* Update training_loop.py
* fix test after introducing monitor keyword

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: William Falcon <[email protected]>
Labels: feature (Is an improvement or enhancement), ready (PRs ready to be merged)

Successfully merging this pull request may close these issues:

  • Support stepping options for lr scheduler
  • Enable stepwise processing flag for schedulers