Add reduce_on_plateau LR scheduler to contrib directory. #629
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.

Thanks for the contribution @vz415! There seem to be some minor issues regarding formatting that break the tests (see https://github.com/google-deepmind/optax/actions/runs/6830352294). Other than that this looks good to me. Green light to merge once the tests pass @mtthss?

@fabianp, fixed the spacing/formatting issue 🤦♂️ and everything should be ready to merge.
optax/contrib/prodigy.py (Outdated)

@@ -46,7 +46,7 @@ class ProdigyState(NamedTuple):
 def prodigy(
     learning_rate: base.ScalarOrSchedule = 0.1,
     betas: tuple[float, float] = (0.9, 0.999),

Suggested change:
-    beta3: float | None = None,
+    beta3: float = None,
Thanks for catching that! I fixed it, you'll need to merge with main.
Cool, pulled the most recent main commit to fix this and pushed.
Thanks @vz415 for the changes! A couple more minor things and then I think we're ready to merge:

Hi @fabianp, I've addressed the issues below.
docs/api.rst (Outdated)

@@ -619,6 +619,7 @@ Schedules
 .. autofunction:: piecewise_constant_schedule
 .. autofunction:: piecewise_interpolate_schedule
 .. autofunction:: polynomial_schedule
+.. autofunction:: optax.contrib.reduce_on_plateau.reduce_on_plateau
Suggested change:
-.. autofunction:: optax.contrib.reduce_on_plateau.reduce_on_plateau
+.. autofunction:: optax.contrib.reduce_on_plateau
optax/contrib/reduce_on_plateau.py (Outdated)

    min_improvement:float,
    cooldown:int
) -> base.GradientTransformationExtraArgs:
  """ Args:
Suggested change:
-  """ Args:
+  """Reduce learning rate when a metric has stopped improving.
+
+  Models often benefit from reducing the learning rate once learning stagnates.
+  This scheduler reads a metric quantity and, if no improvement is seen for a
+  'patience' number of epochs, the learning rate is reduced.
+
+  Args:
optax/contrib/reduce_on_plateau.py (Outdated)

def reduce_on_plateau(
    reduce_factor: float,
    patience: int,
    min_improvement:float,
Suggested change:
-    min_improvement:float,
+    min_improvement: float,
optax/contrib/reduce_on_plateau.py (Outdated)

    reduce_factor: float,
    patience: int,
    min_improvement:float,
    cooldown:int
Suggested change:
-    cooldown:int
+    cooldown: int
optax/contrib/reduce_on_plateau.py (Outdated)

  """ Args:
    reduce_factor: Factor by which the learning rate will be reduced.
      new_lr = lr * factor.
    patience: Number of epochs with no improvement after which learning
Why are these parameters stated in terms of number of epochs? Wouldn't it be more accurate to talk about number of iterations? It seems to me that the method has no knowledge of epochs, and instead keeps track of iterations.
Thanks for catching that. This seems to be a carryover from the PyTorch documentation. I agree that the current version should be stated in terms of iterations rather than epochs; the function itself would need to change to support epochs (I've only used it for full-batch training in my own research), unless I'm missing something.
Is there a reason why the parameters reduce_factor, patience, etc. need to be passed to the state ReduceLROnPlateauState?
Also, in its current state, I think this method is not straightforward to use (at least it was not clear to me). Would you be willing to add a docstring or create an example on how to combine this scaling with (say) SGD or Adam?
Thanks for the review and feedback! Answering your question: those parameters are passed to update the learning rate after every step. They could be externalized via functools.partial, but I think this format makes the scheduler easier for users to apply, although more implicit in how the updates are carried out. Let me know if this makes sense or if you have other concerns. I'll add a docstring with an example on how to use it; if that's not clear, I can create a full example.
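To make the mechanics being discussed concrete, here is a minimal, self-contained sketch of the plateau-tracking logic in plain Python. This is not the actual optax implementation; the state fields, function name, and default values are illustrative, following the reduce_factor / patience / min_improvement / cooldown parameters described above.

```python
from typing import NamedTuple


class PlateauState(NamedTuple):
    scale: float          # multiplier applied to the base learning rate
    best_loss: float      # best metric value seen so far
    plateau_count: int    # consecutive steps without sufficient improvement
    cooldown_count: int   # steps remaining before plateaus are counted again


def plateau_step(loss: float, state: PlateauState, *,
                 reduce_factor: float = 0.5, patience: int = 2,
                 min_improvement: float = 1e-4,
                 cooldown: int = 2) -> PlateauState:
    """One per-iteration update of the plateau tracker (illustrative only)."""
    if state.cooldown_count > 0:
        # During cooldown, only track the best value; don't count plateaus.
        return state._replace(cooldown_count=state.cooldown_count - 1,
                              best_loss=min(state.best_loss, loss))
    if loss < state.best_loss - min_improvement:
        # Sufficient improvement: record it and reset the patience counter.
        return state._replace(best_loss=loss, plateau_count=0)
    if state.plateau_count + 1 < patience:
        return state._replace(plateau_count=state.plateau_count + 1)
    # `patience` steps with no improvement: shrink the scale, start cooldown.
    return state._replace(scale=state.scale * reduce_factor,
                          plateau_count=0, cooldown_count=cooldown)
```

In a training loop one would multiply the optimizer's updates (or its base learning rate) by `state.scale` each step; the counters make clear that the tracker advances per iteration, not per epoch, which is the point raised above.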
This is now merged with some changes. Among other things, I changed some keyword arguments to match those of the PyTorch implementation. Other changes were made to fix some internal failures, since the internal pytype checks are a bit more stringent than the ones that run on the GitHub actions. If you feel that some of these changes are for the worse, please open another PR and propose modifications. Regarding documentation, as I mentioned earlier, I think it is important to have an example on using this function; I created issue #679 to track progress on this.
Thanks for the answer @vz415 (and for your contribution)! I saw your comment after I had submitted. In any case, please don't hesitate to open another PR to edit the submitted code.
Following pull request #505 and issue #221, with @mtthss's suggestions on changing where the scheduler is located (contrib) and on using the GradientTransformationExtraArgs API. Let me know if there's anything else that needs to be done to merge this. Cheers.
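For readers unfamiliar with the GradientTransformationExtraArgs pattern mentioned in the description, here is a rough, hypothetical sketch of the idea in plain Python: a transform whose `update` accepts extra keyword arguments (here the loss, as `value`) beyond the usual updates/state/params. The `Transform` NamedTuple and the toy `scale_on_high_loss` transform are stand-ins for illustration, not the real optax types.

```python
from typing import Any, Callable, NamedTuple


class Transform(NamedTuple):
    """Minimal stand-in for an extra-args gradient transformation:
    `update` takes extra keyword arguments in addition to updates/state."""
    init: Callable[[Any], Any]
    update: Callable[..., Any]


class ScaleState(NamedTuple):
    scale: float  # current multiplier applied to the updates


def scale_on_high_loss(threshold: float, factor: float) -> Transform:
    """Toy transform: shrink the step scale whenever the loss exceeds `threshold`."""
    def init(params):
        del params  # no per-parameter state needed for this toy example
        return ScaleState(scale=1.0)

    def update(updates, state, params=None, *, value):
        del params
        # `value` is the loss, threaded through as an extra keyword argument.
        scale = state.scale * factor if value > threshold else state.scale
        new_updates = [u * scale for u in updates]
        return new_updates, ScaleState(scale=scale)

    return Transform(init, update)
```

Usage mirrors the pattern discussed in the thread: the training loop calls `transform.update(updates, state, value=loss)` each step, so the scheduler can react to the metric without the optimizer itself knowing about it.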