-
Notifications
You must be signed in to change notification settings - Fork 193
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
schedule_free: fix broadcasting of scalar arrays to 1d arrays #1042
schedule_free: fix broadcasting of scalar arrays to 1d arrays #1042
Conversation
your solution seems reasonable to me. However, there are now some doctest errors in |
@fabianp I am quite unfamiliar with the docs. I thought this would be a simple change. Is there some documentation on this? Otherwise, I can also remove the doc changes. |
you might not need to build the docs (although if you wanted to, its described in the README). Just check the errors from the failing CI (https://github.com/google-deepmind/optax/actions/runs/10664454490/job/29555745356?pr=1042). as you can see, the issue seems to be that some examples in the docstrings use let me know if this doesn't make sense |
I haven't touched any line related to |
I don't get why the tests are failing. it works locally and the change seems unrelated. @fabianp do you have another idea? |
you can undo the changes in docs/api/contrib.rst if you want since they are orthogonal to this PR |
done |
can you also add a test showing that the new approach doesn't have the broadcasting problem? |
I added a test that fails before and succeeds after the PR |
optax/contrib/_schedule_free_test.py
Outdated
@@ -164,5 +164,16 @@ def run(opt): | |||
params_wrapper = run(opt_wrapper) | |||
chex.assert_trees_all_close(params_shortcut, params_wrapper) | |||
|
|||
@parameterized.parameters(*_OPTIMIZERS_UNDER_TEST) | |||
def test_scalar_preservance(self, opt_name, opt_kwargs): | |||
opt = getattr(alias, opt_name)(learning_rate=0.0, **opt_kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
perhaps call this base_opt ? otherwise both the base and the wrapper have the same name which is confusing
@@ -164,5 +164,16 @@ def run(opt): | |||
params_wrapper = run(opt_wrapper) | |||
chex.assert_trees_all_close(params_shortcut, params_wrapper) | |||
|
|||
@parameterized.parameters(*_OPTIMIZERS_UNDER_TEST) | |||
def test_scalar_preservance(self, opt_name, opt_kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the name test_scalar_preservance
is not very descriptive. Please either use a more precise name for the test or add a comment below explaining what behavior the function is testing
Added a comment and changed the variable name. Though, this criticism probably applies to all the other tests in that file. |
excellent, thanks! |
Currently, the momentum is stored in a 1D array
[b1]
of shape(1,)
. We should store it instead in a scalar array()
to avoid broadcasting scalars to(1,)
inschedule_free_eval_params
.Before:
After: