schedule_free: fix broadcasting of scalar arrays to 1d arrays #1042

Merged: 7 commits into google-deepmind:main on Sep 4, 2024

Conversation

@n-gao (Contributor) commented Sep 2, 2024

Currently, the momentum b1 is stored in a 1D array of shape (1,). We should instead store it in a scalar array of shape () to avoid broadcasting scalar parameters to (1,) in schedule_free_eval_params.
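For context, this follows from NumPy-style broadcasting: an operation between a shape-(1,) array and a scalar array yields a shape-(1,) result. A minimal sketch of the effect (illustrative variable names, not optax internals):

import jax.numpy as jnp

b1_1d = jnp.ones((1,))  # momentum stored with shape (1,)
b1_0d = jnp.ones(())    # momentum stored with shape ()
x = jnp.ones(())        # a scalar parameter

(b1_1d * x).shape  # (1,) -- the scalar parameter is broadcast up to 1D
(b1_0d * x).shape  # ()   -- the parameter's shape is preserved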

Before:

import jax.numpy as jnp
import optax

opt = optax.contrib.schedule_free_adamw()
x = jnp.ones(())
state = opt.init(x)
optax.contrib.schedule_free_eval_params(state, x).shape
# (1,)

After:

import jax.numpy as jnp
import optax

opt = optax.contrib.schedule_free_adamw()
x = jnp.ones(())
state = opt.init(x)
optax.contrib.schedule_free_eval_params(state, x).shape
# ()

@fabianp (Member) commented Sep 2, 2024

Your solution seems reasonable to me. However, there are now some doctest errors in optax/contrib/_schedule_free.py, probably because they are now included in the docs.

@n-gao (Contributor Author) commented Sep 2, 2024

@fabianp I am quite unfamiliar with the docs. I thought this would be a simple change. Is there some documentation on this? Otherwise, I can also remove the doc changes.

@fabianp (Member) commented Sep 2, 2024

You might not need to build the docs (although if you want to, it's described in the README). Just check the errors from the failing CI (https://github.com/google-deepmind/optax/actions/runs/10664454490/job/29555745356?pr=1042). As you can see, the issue seems to be that some examples in the docstrings use schedule_free_eval_params instead of the full name optax.contrib.schedule...
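For illustration only (a hedged sketch, not the actual docstring), a fully qualified doctest example would look roughly like this:

>>> import optax
>>> opt = optax.contrib.schedule_free_adamw()  # full name resolves in the doctest
>>> # a bare schedule_free_adamw() would raise a NameError when the docs run it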

let me know if this doesn't make sense

@n-gao (Contributor Author) commented Sep 2, 2024

I haven't touched any line related to schedule_free_eval_params. The other entries in docs/api/contrib.rst don't use complete paths either, do they? Let me check whether the tests pass if I remove the lines again.

@n-gao (Contributor Author) commented Sep 2, 2024

I don't get why the tests are failing. It works locally, and the change seems unrelated. @fabianp, do you have another idea?

@fabianp (Member) commented Sep 2, 2024

You can undo the changes in docs/api/contrib.rst if you want, since they are orthogonal to this PR.

@n-gao (Contributor Author) commented Sep 2, 2024

done

@fabianp (Member) commented Sep 2, 2024

Can you also add a test showing that the new approach doesn't have the broadcasting problem?

@n-gao (Contributor Author) commented Sep 2, 2024

I added a test that fails before and succeeds after the PR.

@@ -164,5 +164,16 @@ def run(opt):
     params_wrapper = run(opt_wrapper)
     chex.assert_trees_all_close(params_shortcut, params_wrapper)
 
+  @parameterized.parameters(*_OPTIMIZERS_UNDER_TEST)
+  def test_scalar_preservance(self, opt_name, opt_kwargs):
+    opt = getattr(alias, opt_name)(learning_rate=0.0, **opt_kwargs)
Review comment (Member):
Perhaps call this base_opt? Otherwise both the base and the wrapper have the same name, which is confusing.

@@ -164,5 +164,16 @@ def run(opt):
     params_wrapper = run(opt_wrapper)
     chex.assert_trees_all_close(params_shortcut, params_wrapper)
 
+  @parameterized.parameters(*_OPTIMIZERS_UNDER_TEST)
+  def test_scalar_preservance(self, opt_name, opt_kwargs):
Review comment (Member):
The name test_scalar_preservance is not very descriptive. Please either use a more precise name for the test or add a comment below it explaining what behavior the function is testing.

@n-gao (Contributor Author) commented Sep 3, 2024

Added a comment and changed the variable name. Though this criticism probably applies to all the other tests in that file, too.

@fabianp (Member) commented Sep 3, 2024

excellent, thanks!

copybara-service bot merged commit 896cb88 into google-deepmind:main on Sep 4, 2024. 8 checks passed.