Pytree-based Optimizers #432
Comments
@mkunesch @rosshemsley @hbq1 let us know what you think 👍
@cgarciae Sorry for taking a while to respond, and thanks for sharing your design proposal! There have been a few projects recently working on attaching functions to custom pytrees in JAX (e.g. equinox), and this proposal seems to have some similar ideas. For the reasons you mentioned, this factoring can be attractive, although it's worth highlighting that there are some downsides to this approach, too:
We are continuing to work on polishing the optax API, although we are also deliberately being conservative about the changes we make. Part of what makes optax successful is its ruthless simplicity, and forking the API with two sets of alternative factorings would increase the API surface area and could make it harder for us to support. We're currently working on improving the package factoring, which will hopefully leave optax in a better place for trying out some more experimental ideas (such as this kind of API factoring), but it may be a little while before we would want to introduce big changes like this to the core library.
We'd encourage you to keep thinking about this idea though, especially with regard to 2) above. Optax has thousands of users at the moment, so charting a path forwards whilst retaining checkpoint compatibility is probably the biggest barrier we have to making these kinds of changes.
It would also be a good idea to try and "break" this design. For example, what happens when using more esoteric JAX transforms (such as vmap, pmap, pjit, or grad)? Can you break this design through unexpected jit placement? (As a rule, someone has already done one of these things somewhere to every optax optimizer.)
Just some thoughts:
This proposal doesn't change anything for users since the interface is identical except for the four benefits mentioned. The reasoning that users have to do is exactly the same. If anything, the user reasoning is simpler since the sequence interface is not exposed. It's unfortunate that we didn't reconcile this issue back when it was suggested in the very first Optax issue.
Why don't you try breaking it? I think it might make the benefits more apparent. I also think it would be good to at least block uses of the sequence interface, which are misuses of the current optax design. This will make it easier to improve your design in the future.
This topic comes to mind every once in a while. It has already been discussed extensively (e.g. #197 (comment)), but I feel it needs new life because it could resolve the last remaining quirks in optax.
Optax optimizers have a well-defined API and, unlike neural networks, they have a clear way to update their state, which makes them perfectly suitable for pytree/dataclass interfaces. Similar to what @NeilGirdhar has done here, one could express pytree versions of all optimizers by wrapping functional optax, with the added benefits:
jit
..replace()
.opt_state
separation.
Example
For this example I'll be using Flax's PyTreeNode, but any pytree implementation is just as good.
Proposal
Given that any community shim will probably not succeed, how about an optax.pytree namespace (naming suggestions are welcome) where a shim could officially live and be discussed with the core team?