Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support torch.optim.lr_scheduler.ReduceLROnPlateau #320

Merged

Conversation

vikmary
Copy link
Contributor

@vikmary vikmary commented Oct 6, 2019

Support of torch.optim.lr_scheduler.ReduceLROnPlateau by implementing pytorch_lightning.Callback .

Fixes #298.

ReduceLROnPlateau should be inited as any other scheduler:

class MyModule(pl.LightningModule):
  def configure_optimizers(self):
    optimizer = optim.Adam(self.parameters(), lr=self.params.learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode='min',
                                                     factor=0.1,
                                                     patience=10,
                                                     min_lr=1e-6,
                                                     verbose=True)
    return [optimizer], [scheduler]

From contribution guide:

@williamFalcon
Copy link
Contributor

@vikmary looks like some of the old changes leaked into this PR. Could you rebase master into this and submit again? Thanks!

@vikmary
Copy link
Contributor Author

vikmary commented Oct 9, 2019

@vikmary looks like some of the old changes leaked into this PR. Could you rebase master into this and submit again? Thanks!

Merged origin master again.

@williamFalcon
Copy link
Contributor

ummm. still pulls in the wrong master.

can you pull your local master first then apply the rebase to your branch?

basically all the messaging and ddp stuff shouldn’t change

@vikmary
Copy link
Contributor Author

vikmary commented Oct 9, 2019

ummm. still pulls in the wrong master.

can you pull your local master first then apply the rebase to your branch?

basically all the messaging and ddp stuff shouldn’t change

That's what I did. Can you detect what's wrong?

cd my-fork
git checkout master
git pull https://github.com/williamFalcon/pytorch-lightning.git master
git checkout support-reduceonplateau-lr-scheduler
git merge master
git push

@vikmary
Copy link
Contributor Author

vikmary commented Oct 9, 2019

ummm. still pulls in the wrong master.

can you pull your local master first then apply the rebase to your branch?

basically all the messaging and ddp stuff shouldn’t change

Diff looks strange. That's i guess what you are talking about. I'll redo the merging, thanks

@williamFalcon
Copy link
Contributor

@vikmary ah it looks great now haha. I added a comment in the review

@vikmary
Copy link
Contributor Author

vikmary commented Oct 9, 2019

@williamFalcon Can't find the comment, did you submit the review?

@williamFalcon
Copy link
Contributor

ummm. ok.

My question was about why we need a custom implementation of the reduceLROnPlateu? We should be using the default PyTorch one

@williamFalcon
Copy link
Contributor

(if you click on the "Files Changed" tab you'll see my original comment and links

@vikmary
Copy link
Contributor Author

vikmary commented Oct 11, 2019

Sorry, my "Files Changes" tab is crystal clear from comments. Did you click "Add single comment" when submitting a comment or "Start a Review"? If you start a review, then you have to submit the whole review in order for comments to appear (as it seems to me).

We ARE using the default PyTorch implementation of ReduceLROnPlateau. We even can get rid of my class implementation (it only makes code cleaner).

The only problem with default PyTorch implementation it that we have to pass VALIDATION LOSS ato scheduler.step method. That is exactly what pytorch-lightning/callbacks:ReduceLROnPlateau class is doing.

Comment on lines 147 to 176
class ReduceLROnPlateauScheduler(Callback):
"""
Reduce learning rate when the monitored metric has stopped improving.
Wrapper for torch.optim.lr_schuduler.ReduceLROnPlateau learning rate
schedulers.

# Arguments
schedulers: list of torch.optim.lr_scheduler.ReduceLROnPlateau
monitor: quantity to be monitored.
"""

def __init__(self, schedulers, monitor='val_loss'):
super(ReduceLROnPlateauScheduler, self).__init__()

self.monitor = monitor
self.schedulers = schedulers

def on_epoch_end(self, epoch, logs=None):
current = logs.get(self.monitor)
stop_training = False
if current is None:
print('ReduceLROnPlateau conditioned on metric `%s` '
'which is not available. Available metrics are: %s' %
(self.monitor, ','.join(list(logs.keys()))), RuntimeWarning)
exit(-1)

for scheduler in self.schedulers:
scheduler.step(current, epoch=epoch)


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to create our own ReduceLROnPlateauScheduler?
We should be operating directly on the PyTorch one (https://pytorch.org/docs/stable/optim.html?highlight=reducelr#torch.optim.lr_scheduler.ReduceLROnPlateau)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ReduceLROnPlateauScheduler.schedulers is a list of orginal torch.optim.lr_scheduler.ReduceLROnPlateau, see the proof in a comment below

custom_schedulers = []
i = 0
while i < len(schedulers):
if isinstance(schedulers[i], torch.optim.lr_scheduler.ReduceLROnPlateau):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

proof here

Comment on lines 807 to 810
while i < len(schedulers):
if isinstance(schedulers[i], torch.optim.lr_scheduler.ReduceLROnPlateau):
custom_schedulers.append(schedulers.pop(i))
i += 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a small issue with this snippet. When the ReduceLROnPlateau optimizer is pop'ed, the i should not be increased, otherwise, the element following the element being pop'ed ends up in position schedulers[i] and then i is immediately increased, so that element never gets checked ( to see if it is another ReduceLROnPlateau.
This would not cause any problems when there is only on ReduceLROnPlateau scheduler.
While there is probably no reason to have more than one ReduceLROnPlateau scheduler, it would be nicer to change the code. Either move the i+=1 to an else branch, or ( if we make assumption that only one ReduceLROnPlateau scheduler is present ) break out of the loop.

Suggested change
while i < len(schedulers):
if isinstance(schedulers[i], torch.optim.lr_scheduler.ReduceLROnPlateau):
custom_schedulers.append(schedulers.pop(i))
i += 1
else:
i += 1

Copy link
Contributor Author

@vikmary vikmary Oct 20, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, thank you for the fix. I decided to support only one ReduceLROnPlateau scheduler.

@@ -1,7 +1,8 @@
from .pt_callbacks import EarlyStopping, ModelCheckpoint, GradientAccumulationScheduler
from .pt_callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateauScheduler, GradientAccumulationScheduler
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looks like a relative import which we shall not use... :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Relative imports of EarlyStopping, ModelCheckpoint etc. are taken from the original repository. Why is relative import of ReduceLROnPlateauScheduler inappropriate?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#402 never fixed relative imports in callbacks init as well as in many other places. I'd say that above comment is out of scope for this PR. @Borda it might be better to create a separate PR that will properly fix relative imports.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was not that the PR was fixing relative imports, but I tried to make them which was stopped...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's do a separate PR for relative imports.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have opened ticket #459

@williamFalcon
Copy link
Contributor

williamFalcon commented Nov 5, 2019

@vikmary let's get this into the next release on nov 6.

  1. can you rebase master onto this?
  2. do we need that class you made? can't we just use the default pytorch one? maybe i'm missing something or reading this too quickly.

@Borda thoughts?

@vikmary
Copy link
Contributor Author

vikmary commented Nov 5, 2019

@vikmary let's get this into the next release on nov 6.

  1. can you rebase master onto this?
  2. do we need that class you made? can't we just use the default pytorch one? maybe i'm missing something or reading this too quickly.

@Borda thoughts?

I'll try to rebase during today.

@Borda
Copy link
Member

Borda commented Nov 5, 2019

there are 3 options and I would try them in the following order:

(for all three options maybe be easier to squash all your commits to one)

@vikmary
Copy link
Contributor Author

vikmary commented Nov 5, 2019

The merging needs some refactoring, will not be able to finish today

@vikmary
Copy link
Contributor Author

vikmary commented Nov 7, 2019

Hi, I refactored so that:

  • there is no Callback class, we rely solely on torch.optim.lr_scheduler.ReduceLROnPlateau
  • fixed logging error in ModelCheckoint.on_epoch_end
    @Borda @williamFalcon

@williamFalcon
Copy link
Contributor

williamFalcon commented Nov 30, 2019

@Borda @Ir1d merging this, look good (looks fine to me)?

Copy link
Member

@Borda Borda left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would not use exit otherwise LGTM

pytorch_lightning/trainer/train_loop_mixin.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/trainer.py Outdated Show resolved Hide resolved
lr_scheduler.step(self.current_epoch)
lr_scheduler.step(epoch=self.current_epoch)
if self.reduce_lr_on_plateau_scheduler is not None:
val_loss = self.callback_metrics.get('val_loss')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems very specific. does it only need to work with val_loss?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it could be any validation metric in theory, but how do we let the user pass it in? Perhaps via a dedicated dict entry in the validation_end output similar to "log" and "progress_bar"? It's one more thing the user needs to remember, but maybe its fine since this lr_scheduler is optional and it should be mentioned in the docs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, good option. let’s do it in a separate PR?

@Ir1d
Copy link
Contributor

Ir1d commented Dec 1, 2019

I have one offtopic question..
Should this be epoch instead of epoch + 1?

https://github.com/williamFalcon/pytorch-lightning/pull/320/files#diff-f1ccb073775f3b2f9c294bd887086da3L202

@williamFalcon williamFalcon merged commit a6d64ac into Lightning-AI:master Dec 3, 2019
@vikmary
Copy link
Contributor Author

vikmary commented Dec 3, 2019

Thank you
It was a long-awaiting pull-request
Hope it has some value for framework in spite of non-trivial master integration

@williamFalcon
Copy link
Contributor

haha. Thank you so much for contributing! feel free to keep helping out as we’re always growing our core team!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support of optim.lr_scheduler.ReduceLROnPlateau
6 participants