Support torch.optim.lr_scheduler.ReduceLROnPlateau #320
Conversation
@vikmary looks like some of the old changes leaked into this PR. Could you rebase master into this and submit again? Thanks!
Merged origin master again.
ummm. still pulls in the wrong master. can you pull your local master first then apply the rebase to your branch? basically all the messaging and ddp stuff shouldn't change
That's what I did. Can you detect what's wrong?

cd my-fork
git checkout master
git pull https://github.com/williamFalcon/pytorch-lightning.git master
git checkout support-reduceonplateau-lr-scheduler
git merge master
git push
Diff looks strange. That's, I guess, what you are talking about. I'll redo the merging, thanks.
@vikmary ah it looks great now haha. I added a comment in the review
@williamFalcon Can't find the comment, did you submit the review?
ummm. ok. My question was about why we need a custom implementation of ReduceLROnPlateau? We should be using the default PyTorch one
(if you click on the "Files Changed" tab you'll see my original comment and links)
Sorry, my "Files Changed" tab is completely clear of comments. Did you click "Add single comment" when submitting the comment, or "Start a Review"? If you start a review, then you have to submit the whole review in order for the comments to appear (as it seems to me). We ARE using the default PyTorch implementation of ReduceLROnPlateau. We could even get rid of my class implementation (it only makes the code cleaner). The only problem with the default PyTorch implementation is that we have to pass the VALIDATION LOSS to its step() method.
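For context, this is how the stock PyTorch scheduler is driven; a minimal sketch, with the model, optimizer, and metric values purely illustrative:

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Unlike epoch-based schedulers, ReduceLROnPlateau must be given the
# monitored value in every call to step().
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=5)

for epoch in range(20):
    # ... training and validation would happen here ...
    val_loss = 0.5  # value computed on the validation set
    scheduler.step(val_loss)  # the scheduler only reacts to this metric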
class ReduceLROnPlateauScheduler(Callback):
    """
    Reduce learning rate when the monitored metric has stopped improving.
    Wrapper for torch.optim.lr_scheduler.ReduceLROnPlateau learning rate
    schedulers.

    # Arguments
        schedulers: list of torch.optim.lr_scheduler.ReduceLROnPlateau
        monitor: quantity to be monitored.
    """

    def __init__(self, schedulers, monitor='val_loss'):
        super(ReduceLROnPlateauScheduler, self).__init__()

        self.monitor = monitor
        self.schedulers = schedulers

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get(self.monitor)
        stop_training = False
        if current is None:
            print('ReduceLROnPlateau conditioned on metric `%s` '
                  'which is not available. Available metrics are: %s' %
                  (self.monitor, ','.join(list(logs.keys()))), RuntimeWarning)
            exit(-1)

        for scheduler in self.schedulers:
            scheduler.step(current, epoch=epoch)
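For illustration only, here is how this callback could be driven by hand, using the class defined above (the scheduler settings and metric value are made up):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)
callback = ReduceLROnPlateauScheduler([plateau], monitor='val_loss')

# At the end of each epoch the trainer hands the callback the collected
# metrics; the callback then steps every wrapped scheduler with that value.
callback.on_epoch_end(epoch=0, logs={'val_loss': 0.42})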
Why do we need to create our own ReduceLROnPlateauScheduler?
We should be operating directly on the PyTorch one (https://pytorch.org/docs/stable/optim.html?highlight=reducelr#torch.optim.lr_scheduler.ReduceLROnPlateau)
ReduceLROnPlateauScheduler.schedulers is a list of original torch.optim.lr_scheduler.ReduceLROnPlateau schedulers, see the proof in a comment below.
pytorch_lightning/trainer/trainer.py (outdated)
custom_schedulers = []
i = 0
while i < len(schedulers):
    if isinstance(schedulers[i], torch.optim.lr_scheduler.ReduceLROnPlateau):
proof here
pytorch_lightning/trainer/trainer.py (outdated)
while i < len(schedulers):
    if isinstance(schedulers[i], torch.optim.lr_scheduler.ReduceLROnPlateau):
        custom_schedulers.append(schedulers.pop(i))
    i += 1
There is a small issue with this snippet. When the ReduceLROnPlateau scheduler is popped, i should not be increased; otherwise the element that follows the popped one ends up at schedulers[i], and since i is immediately increased, that element is never checked (to see whether it is another ReduceLROnPlateau).
This would not cause any problems when there is only one ReduceLROnPlateau scheduler.
While there is probably no reason to have more than one ReduceLROnPlateau scheduler, it would be nicer to change the code: either move the i += 1 into an else branch, or (if we assume that only one ReduceLROnPlateau scheduler is present) break out of the loop.
while i < len(schedulers):
    if isinstance(schedulers[i], torch.optim.lr_scheduler.ReduceLROnPlateau):
        custom_schedulers.append(schedulers.pop(i))
    else:
        i += 1
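Just as an illustrative alternative (not what the PR does), the same split can be written without any index bookkeeping, assuming the schedulers list from the snippet above:

# Partition the schedulers: plateau schedulers get special handling,
# everything else keeps the regular epoch-based stepping.
custom_schedulers = [s for s in schedulers
                     if isinstance(s, torch.optim.lr_scheduler.ReduceLROnPlateau)]
schedulers = [s for s in schedulers
              if not isinstance(s, torch.optim.lr_scheduler.ReduceLROnPlateau)]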
Hi, thank you for the fix. I decided to support only one ReduceLROnPlateau scheduler.
@@ -1,7 +1,8 @@
-from .pt_callbacks import EarlyStopping, ModelCheckpoint, GradientAccumulationScheduler
+from .pt_callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateauScheduler, GradientAccumulationScheduler
it looks like a relative import which we shall not use... :)
Relative imports of EarlyStopping, ModelCheckpoint, etc. are taken from the original repository. Why is a relative import of ReduceLROnPlateauScheduler inappropriate?
It was not that this PR was fixing the relative imports; I had tried to change them earlier, but that was stopped...
let's do a separate PR for relative imports.
I have opened ticket #459
there are 3 options and I would try them in the following order:
(for all three options it may be easier to squash all your commits into one)
The merging needs some refactoring; I will not be able to finish it today.
Hi, I refactored so that:
I would not use exit; otherwise LGTM.
-lr_scheduler.step(self.current_epoch)
+lr_scheduler.step(epoch=self.current_epoch)
+if self.reduce_lr_on_plateau_scheduler is not None:
+    val_loss = self.callback_metrics.get('val_loss')
this seems very specific. does it only need to work with val_loss?
I think it could be any validation metric in theory, but how do we let the user pass it in? Perhaps via a dedicated dict entry in the validation_end output, similar to "log" and "progress_bar"? It's one more thing the user needs to remember, but maybe it's fine since this lr_scheduler is optional and it should be mentioned in the docs.
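To make that idea concrete, such an interface might look roughly like this; everything here, including the extra key name, is hypothetical and not part of this PR:

def validation_end(self, outputs):
    avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    return {
        'log': {'val_loss': avg_loss},
        'progress_bar': {'val_loss': avg_loss},
        # hypothetical extra key naming the metric the plateau scheduler monitors
        'reduce_lr_on_plateau_metric': avg_loss,
    }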
yeah, good option. let’s do it in a separate PR?
I have one off-topic question...
Thank you
haha. Thank you so much for contributing! feel free to keep helping out as we're always growing our core team!
Support for torch.optim.lr_scheduler.ReduceLROnPlateau by implementing a pytorch_lightning.Callback.
Fixes #298.
ReduceLROnPlateau should be initialized like any other scheduler:
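The exact example from the PR description is not preserved in this dump; the following is a minimal sketch of what returning a ReduceLROnPlateau scheduler from configure_optimizers could look like (the model name, import path, and hyperparameters are assumptions, not taken from the PR):

import torch
import pytorch_lightning as pl

class CoolModel(pl.LightningModule):
    # ... __init__, forward, training_step, validation_step, etc. ...

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # ReduceLROnPlateau is constructed like any other scheduler and
        # returned alongside the optimizer; the trainer detects it and
        # feeds it the monitored validation loss.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
        return [optimizer], [scheduler]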
From contribution guide: