
Add reduce_on_plateau LR scheduler to contrib directory. #629

Merged: 9 commits merged into google-deepmind:master on Dec 19, 2023

Conversation

vz415 (Contributor) commented Nov 10, 2023

Following pull request #505 and issue #221, this incorporates @mtthss's suggestions on moving the scheduler to contrib and using the GradientTransformationExtraArgs API. Let me know if there's anything else that needs to be done to merge this. Cheers.


google-cla bot commented Nov 10, 2023

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up-to-date status, view the checks section at the bottom of the pull request.

fabianp (Member) commented Dec 10, 2023

Thanks for the contribution @vz415 ! There seem to be some minor issues regarding formatting that break the tests (see https://github.com/google-deepmind/optax/actions/runs/6830352294).

Other than that this looks good to me. Green light to merge once the tests pass @mtthss ?

vz415 (Contributor, Author) commented Dec 11, 2023

@fabianp I fixed the spacing/formatting issue 🤦‍♂️ and everything should be ready to merge.

@@ -46,7 +46,7 @@ class ProdigyState(NamedTuple):
 def prodigy(
     learning_rate: base.ScalarOrSchedule = 0.1,
     betas: tuple[float, float] = (0.9, 0.999),
-    beta3: float | None = None,
+    beta3: float = None,
Collaborator commented:

Thanks for catching that! I fixed it, you'll need to merge with main.

vz415 (Contributor, Author) commented:

Cool, pulled the most recent main commit to fix this and pushed.

fabianp (Member) commented Dec 12, 2023

Thanks @vz415 for the changes! A couple more minor things and then I think we're ready to merge:

  1. Is there any paper we could cite that describes this approach? If so, please add it under "References" in the docstring (see examples in _src/alias.py).
  2. Please add this new method to the API reference (docs/api.rst).

vz415 (Contributor, Author) commented Dec 12, 2023

Hi @fabianp , I've addressed the issues below.

  1. There's no seminal paper on the technique, but here's the PyTorch documentation. I can add some papers I've seen use it, if wanted.
  2. I added it to the API reference. Let me know if I need to edit it.

docs/api.rst Outdated
@@ -619,6 +619,7 @@ Schedules
.. autofunction:: piecewise_constant_schedule
.. autofunction:: piecewise_interpolate_schedule
.. autofunction:: polynomial_schedule
.. autofunction:: optax.contrib.reduce_on_plateau.reduce_on_plateau
Member commented:

Suggested change
.. autofunction:: optax.contrib.reduce_on_plateau.reduce_on_plateau
.. autofunction:: optax.contrib.reduce_on_plateau

min_improvement:float,
cooldown:int
) -> base.GradientTransformationExtraArgs:
""" Args:
Member commented:

Suggested change
""" Args:
"""Reduce learning rate when a metric has stopped improving.
Models often benefit from reducing the learning rate once learning stagnates.
This scheduler reads a metric quantity and, if no improvement is seen for a
`patience` number of epochs, the learning rate is reduced.
Args:

def reduce_on_plateau(
reduce_factor: float,
patience: int,
min_improvement:float,
Member commented:

Suggested change
min_improvement:float,
min_improvement: float,

reduce_factor: float,
patience: int,
min_improvement:float,
cooldown:int
Member commented:

Suggested change
cooldown:int
cooldown: int

""" Args:
reduce_factor: Factor by which the learning rate will be reduced.
new_lr = lr * factor.
patience: Number of epochs with no improvement after which learning
Member commented:

Why are these parameters in terms of number of epochs? Wouldn't it be more accurate to talk about number of iterations? It seems to me that the method has no knowledge of epochs and keeps track of iterations instead.

vz415 (Contributor, Author) commented:

Thanks for catching that. It seems like this is a carryover from the PyTorch documentation. I agree that the current version should be stated in terms of iterations instead of epochs; the function would need to change to support epochs (I've only used it for full-batch training in my own research), unless I'm missing something.
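For concreteness, here is a rough sketch (not the PR's actual implementation, which keeps its bookkeeping in a JAX-compatible state tuple) of the per-update-step counting being discussed, using this PR's proposed parameter names and treating min_improvement as an absolute threshold:

def plateau_step(value, best_value, plateau_count, scale,
                 reduce_factor=0.5, patience=10, min_improvement=1e-4):
  """One bookkeeping step, counted in update steps rather than epochs."""
  if value < best_value - min_improvement:
    return value, 0, scale  # improvement seen: record it and reset the counter
  plateau_count += 1
  if plateau_count >= patience:  # `patience` consecutive steps with no improvement
    return best_value, 0, scale * reduce_factor
  return best_value, plateau_count, scale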

fabianp (Member) commented Dec 18, 2023

Is there a reason why the parameters reduce_factor, patience, etc. need to be passed to the state ReduceLROnPlateauState?

fabianp (Member) commented Dec 18, 2023

Also, in its current state, I think this method is not straightforward to use (at least it was not clear to me). Would you be willing to add a docstring or create an example of how to combine this scaling with (say) SGD or Adam?

vz415 (Contributor, Author) commented Dec 19, 2023

Thanks for the review and feedback! Answering your question: those parameters are passed so that the learning rate can be updated after every epoch. They could be externalized via functools.partial, but I think this format makes it easier for users to implement, although it is more implicit about how the updates are carried out. Let me know if this makes sense or if you have other concerns.

I'll add a docstring with an example of how to use it. If that's not clear, I can create a standalone example.
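For illustration, a usage sketch along the lines requested above; note that the parameter names follow this PR's proposed signature (reduce_factor, patience, min_improvement, cooldown) rather than the names used in the merged version, and the loss is forwarded through the extra value argument per the GradientTransformationExtraArgs API:

import jax
import jax.numpy as jnp
import optax
from optax import contrib

# Chain Adam with the plateau-based scaling: the transformation rescales
# Adam's updates whenever the tracked loss stops improving.
optimizer = optax.chain(
    optax.adam(learning_rate=1e-3),
    contrib.reduce_on_plateau(
        reduce_factor=0.5, patience=10, min_improvement=1e-4, cooldown=5),
)

params = {'w': jnp.zeros(3)}
opt_state = optimizer.init(params)

def loss_fn(params, x, y):
  return jnp.mean((x @ params['w'] - y) ** 2)

x, y = jnp.ones((8, 3)), jnp.zeros(8)
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
# The loss is passed as the extra `value` keyword so the scheduler can
# detect plateaus.
updates, opt_state = optimizer.update(grads, opt_state, params, value=loss)
params = optax.apply_updates(params, updates)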

copybara-service bot merged commit e355cd5 into google-deepmind:master on Dec 19, 2023
6 checks passed
fabianp (Member) commented Dec 19, 2023

This is now merged, with some changes. Among other things, I changed some keyword arguments to match those of the PyTorch implementation.

Other changes were made to fix some internal failures, since the internal pytype checks are a bit more stringent than the ones that run on GitHub Actions.

If you feel that some of these changes are for the worse, please open another PR and propose modifications.

Regarding documentation: as I mentioned earlier, I think it is important to create an example of using this function. I created issue #679 to track progress on this.

fabianp (Member) commented Dec 19, 2023

Thanks for the answer @vz415 (and for your contribution)! I saw your comment after I had submitted. In any case, please don't hesitate to open another PR to edit the submitted code.

fabianp mentioned this pull request on Dec 21, 2023
vroulet mentioned this pull request on Feb 5, 2024