diff --git a/optax/experimental/README.md b/optax/experimental/README.md
new file mode 100644
index 00000000..2251023e
--- /dev/null
+++ b/optax/experimental/README.md
@@ -0,0 +1,3 @@
+As the name suggests, the functions implemented here are subject to
+change. We are currently developing a new Solver API that can encompass
+a broader range of optimizers, such as those relying on line searches.
diff --git a/optax/experimental/api_test.ipynb b/optax/experimental/api_test.ipynb
new file mode 100644
index 00000000..103a44f2
--- /dev/null
+++ b/optax/experimental/api_test.ipynb
@@ -0,0 +1,58 @@
+{
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "vfWSk55u5_E-"
+      },
+      "source": [
+        "# Example of Solver API usage\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "vzyIF6NW6Dwd"
+      },
+      "outputs": [],
+      "source": [
+        "import jax.numpy as jnp\n",
+        "import optax\n",
+        "from optax.experimental import gradient_solver\n",
+        "\n",
+        "def obj_fun(x):\n",
+        "  return jnp.sum(x**2)\n",
+        "\n",
+        "init, step = gradient_solver.gradient_solver(\n",
+        "    obj_fun, optax.adam(learning_rate=1.)\n",
+        ")\n",
+        "\n",
+        "params = jnp.arange(16, dtype=jnp.float32)\n",
+        "state = init(params)\n",
+        "for _ in range(10):\n",
+        "  params, state = step(params, state)\n",
+        "  print(f'Objective value: {obj_fun(params)}')"
+      ]
+    }
+  ],
+  "metadata": {
+    "colab": {
+      "last_runtime": {
+        "build_target": "//learning/grp/tools/ml_python:ml_notebook",
+        "kind": "private"
+      },
+      "private_outputs": true,
+      "provenance": []
+    },
+    "kernelspec": {
+      "display_name": "Python 3",
+      "name": "python3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 0
+}
diff --git a/optax/experimental/gradient_solver.py b/optax/experimental/gradient_solver.py
new file mode 100644
index 00000000..2fcdc2cc
--- /dev/null
+++ b/optax/experimental/gradient_solver.py
@@ -0,0 +1,64 @@
+# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Wraps a GradientTransformation into a Solver."""
+
+from typing import Any, NamedTuple, Union
+
+import jax
+import optax
+import optax.experimental.solver as optax_solver
+import optax.experimental.utils as exp_utils
+
+
+class GradientSolverState(NamedTuple):
+  gt_state: optax.OptState = None
+
+
+def gradient_solver(obj_fn, gradient_transform, obj_fun_has_aux=False):
+  """Wraps a GradientTransformation into a Solver."""
+
+  def init(init_params: optax.Params) -> optax_solver.SolverState:
+    init_gt_state = gradient_transform.init(init_params)
+    init_opt_state = GradientSolverState(init_gt_state)
+    return init_opt_state
+
+  def step(
+      params: optax.Params,
+      state: optax_solver.SolverState,
+      **extra_kwargs: dict[str, Any]
+  ) -> tuple[
+      Union[optax.Params, tuple[optax.Params, Any]], optax_solver.SolverState
+  ]:
+    obj_kwargs, gt_kwargs = exp_utils.split_kwargs(
+        (obj_fn, gradient_transform.update), extra_kwargs
+    )
+    if obj_fun_has_aux:
+      grad, aux = jax.grad(obj_fn, has_aux=obj_fun_has_aux)(
+          params, **obj_kwargs
+      )
+    else:
+      grad = jax.grad(obj_fn)(params, **obj_kwargs)
+      aux = None
+    update, gt_state = gradient_transform.update(
+        grad, state.gt_state, params, **gt_kwargs
+    )
+    next_params = optax.apply_updates(params, update)
+    next_state = GradientSolverState(gt_state)
+    if obj_fun_has_aux:
+      return (next_params, aux), next_state
+    else:
+      return next_params, next_state
+
+  return optax_solver.Solver(init, step)
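Below is a minimal usage sketch (illustrative only, not part of the files in this patch) showing how extra keyword arguments and auxiliary outputs flow through the wrapper above: keywords passed to `step` are routed by `split_kwargs` to whichever function accepts them, and with `obj_fun_has_aux=True` the step returns `(params, aux)` instead of `params`. The `data` argument and the least-squares objective are made up for the example.

import jax.numpy as jnp
import optax
from optax.experimental import gradient_solver


def obj_fun(params, data):
  # Least-squares objective returning the loss and, as auxiliary output,
  # the residuals.
  residuals = params - data
  return jnp.sum(residuals**2), residuals


init, step = gradient_solver.gradient_solver(
    obj_fun, optax.sgd(learning_rate=0.1), obj_fun_has_aux=True
)

params = jnp.zeros(3)
state = init(params)
data = jnp.array([1.0, 2.0, 3.0])
# `data` matches a parameter of `obj_fun` only, so split_kwargs dispatches it
# there; the auxiliary output is returned alongside the updated parameters.
(params, aux), state = step(params, state, data=data)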
diff --git a/optax/experimental/gradient_solver_test.py b/optax/experimental/gradient_solver_test.py
new file mode 100644
index 00000000..1ff0473e
--- /dev/null
+++ b/optax/experimental/gradient_solver_test.py
@@ -0,0 +1,110 @@
+# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `gradient_solver.py`."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import chex
+import jax
+import jax.numpy as jnp
+
+from optax._src import alias
+from optax._src import numerics
+from optax.experimental import gradient_solver
+
+
+_GRAD_TRANSFORMS_UNDER_TEST = (
+    dict(gt_name='sgd', gt_kwargs=dict(learning_rate=1e-3, momentum=0.9)),
+    dict(gt_name='adafactor', gt_kwargs=dict(learning_rate=5e-3)),
+    dict(gt_name='adagrad', gt_kwargs=dict(learning_rate=1.0)),
+    dict(gt_name='adam', gt_kwargs=dict(learning_rate=1e-1)),
+    dict(gt_name='adamw', gt_kwargs=dict(learning_rate=1e-1)),
+    dict(gt_name='adamax', gt_kwargs=dict(learning_rate=1e-1)),
+    dict(gt_name='adamaxw', gt_kwargs=dict(learning_rate=1e-1)),
+    dict(gt_name='amsgrad', gt_kwargs=dict(learning_rate=1e-1)),
+    dict(gt_name='lars', gt_kwargs=dict(learning_rate=1.0)),
+    dict(gt_name='lamb', gt_kwargs=dict(learning_rate=1e-3)),
+    dict(
+        gt_name='lion', gt_kwargs=dict(learning_rate=1e-2, weight_decay=1e-4),
+    ),
+    dict(gt_name='nadam', gt_kwargs=dict(learning_rate=1e-2)),
+    dict(gt_name='nadamw', gt_kwargs=dict(learning_rate=1e-2)),
+    dict(gt_name='noisy_sgd', gt_kwargs=dict(learning_rate=1e-3, eta=1e-4)),
+    dict(gt_name='novograd', gt_kwargs=dict(learning_rate=1e-3)),
+    dict(
+        gt_name='optimistic_gradient_descent',
+        gt_kwargs=dict(learning_rate=2e-3, alpha=0.7, beta=0.1),
+    ),
+    dict(gt_name='rmsprop', gt_kwargs=dict(learning_rate=5e-3)),
+    dict(gt_name='rmsprop', gt_kwargs=dict(learning_rate=5e-3, momentum=0.9)),
+    dict(gt_name='fromage', gt_kwargs=dict(learning_rate=5e-3)),
+    dict(gt_name='adabelief', gt_kwargs=dict(learning_rate=1e-2)),
+    dict(gt_name='radam', gt_kwargs=dict(learning_rate=5e-3)),
+    dict(gt_name='rprop', gt_kwargs=dict(learning_rate=1e-1)),
+    dict(gt_name='sm3', gt_kwargs=dict(learning_rate=1.0)),
+    dict(gt_name='yogi', gt_kwargs=dict(learning_rate=1e-1)),
+)
+
+
+def _setup_parabola(dtype):
+  """Quadratic function as an optimization target."""
+  initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype)
+  final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype)
+
+  def obj_fun(params):
+    return jnp.sum(numerics.abs_sq(params - final_params))
+
+  return initial_params, final_params, obj_fun
+
+
+def _setup_rosenbrock(dtype):
+  """Rosenbrock function as an optimization target."""
+  a = 1.0
+  b = 100.0
+
+  initial_params = jnp.array([0.0, 0.0], dtype=dtype)
+  final_params = jnp.array([a, a**2], dtype=dtype)
+
+  def obj_fun(params):
+    return (numerics.abs_sq(a - params[0]) +
+            b * numerics.abs_sq(params[1] - params[0]**2))
+
+  return initial_params, final_params, obj_fun
+
+
+class SolverWrapperTest(chex.TestCase):
+
+  @parameterized.product(
+      _GRAD_TRANSFORMS_UNDER_TEST,
+      target=(_setup_parabola, _setup_rosenbrock),
+      dtype=(jnp.float32,),
+  )
+  def test_optimization(self, gt_name, gt_kwargs, target, dtype):
+    opt = getattr(alias, gt_name)(**gt_kwargs)
+    initial_params, final_params, obj_fun = target(dtype)
+
+    init, step = gradient_solver.gradient_solver(obj_fun, opt)
+
+    params = initial_params
+    state = init(params)
+    step = jax.jit(step)
+    for _ in range(10_000):
+      params, state = step(params, state)
+
+    chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2)
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/optax/experimental/solver.py b/optax/experimental/solver.py
new file mode 100644
index 00000000..1fda3393
--- /dev/null
+++ b/optax/experimental/solver.py
@@ -0,0 +1,89 @@
+# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Solver API."""
+
+from typing import Any, NamedTuple, Protocol, Union
+
+import optax
+
+Params = optax.Params
+SolverState = Any
+
+
+class SolverInitFn(Protocol):
+  """A callable type for the `init` function of a `Solver`.
+
+  The `init` function takes a tree of `params` and uses these to construct an
+  arbitrarily structured initial `state` for the solver. This state may hold
+  statistics of past updates or any other non-static information.
+  """
+
+  def __call__(self, params: Params) -> SolverState:
+    """Initializes the solver.
+
+    Args:
+      params: The initial value of the parameters.
+
+    Returns:
+      The initial state of the solver.
+    """
+
+
+class SolverStepFn(Protocol):
+  """A callable type for the `step` function of a `Solver`.
+
+  The `step` function takes a tree of candidate parameters `params` and an
+  arbitrarily structured `state`, and returns a new tree of candidate
+  parameters and a new state. Additional arguments can be passed as keyword
+  arguments.
+  """
+
+  def __call__(
+      self, params: Params, state: SolverState, **extra_kwargs: dict[str, Any]
+  ) -> tuple[Union[Params, tuple[Params, Any]], SolverState]:
+    """Performs a step of the solver.
+
+    Args:
+      params: A tree of candidate parameters.
+      state: The state of the solver.
+      **extra_kwargs: Additional keyword arguments for the objective function
+        or the solver.
+
+    Returns:
+      The updated parameters, possibly together with an auxiliary output,
+      and the updated state.
+    """
+
+
+class Solver(NamedTuple):
+  """A pair of pure functions implementing a solver.
+
+  The init function initializes the state of the solver given an initial tree
+  of parameters. The step function updates the parameters and the state of the
+  solver given the current parameters and state.
+  Unlike a GradientTransformation, this API accesses the function to be
+  optimized directly, computing gradients, then update directions, and finally
+  the updated parameters in a single step.
+
+  Attributes:
+    init_fn: A pure function which, when called with an example instance of
+      the parameters, returns an arbitrarily structured initial `state`.
+    step_fn: A pure function which takes as input a tree of parameters and the
+      previous solver state (which may have been initialized using the init
+      function), and returns the updated parameters and a new solver state.
+  """
+
+  init_fn: SolverInitFn
+  step_fn: SolverStepFn
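For concreteness, here is a minimal sketch (illustrative only, not part of the files in this patch) of a Solver written directly against the API above, without going through a GradientTransformation: plain gradient descent with a fixed step size. The name `constant_step_solver` and its `stepsize` argument are made up for the example.

import jax
import jax.numpy as jnp
from optax.experimental import solver as optax_solver


def constant_step_solver(obj_fn, stepsize):
  def init(params):
    # No statistics are tracked; an empty tuple is a valid solver state.
    return ()

  def step(params, state):
    grad = jax.grad(obj_fn)(params)
    next_params = jax.tree_util.tree_map(
        lambda p, g: p - stepsize * g, params, grad
    )
    return next_params, state

  return optax_solver.Solver(init, step)


# Usage mirrors the gradient_solver wrapper: init builds the state, and step
# returns the updated parameters together with a new state.
init, step = constant_step_solver(lambda x: jnp.sum(x**2), stepsize=0.1)
params = jnp.ones(4)
state = init(params)
params, state = step(params, state)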
diff --git a/optax/experimental/utils.py b/optax/experimental/utils.py
new file mode 100644
index 00000000..5f386358
--- /dev/null
+++ b/optax/experimental/utils.py
@@ -0,0 +1,65 @@
+# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for solvers."""
+
+import functools
+import inspect
+import operator
+
+from typing import Any, Callable, Sequence
+
+
+def split_kwargs(
+    funs: Sequence[Callable[..., Any]],
+    fun_kwargs: dict[str, Any],
+) -> Sequence[dict[str, Any]]:
+  """Splits `fun_kwargs` into keyword arguments of the input functions `funs`.
+
+  Raises an error if a keyword argument of `fun_kwargs` does not match any
+  argument name of `funs`.
+
+  Args:
+    funs: sequence of functions to feed `fun_kwargs` to.
+    fun_kwargs: dictionary of keyword arguments to be fed to `funs`.
+
+  Returns:
+    (fun_1_kwargs, ..., fun_n_kwargs): keyword arguments for each function,
+    taken from `fun_kwargs`.
+
+  Examples:
+    >>> def fun1(a, b): return a+b
+    >>> def fun2(c, d): return c+d
+    >>> fun_kwargs = {'b':1., 'd':2.}
+    >>> funs_kwargs = split_kwargs((fun1, fun2), fun_kwargs)
+    >>> print(funs_kwargs)
+    [{'b': 1.0}, {'d': 2.0}]
+  """
+  funs_arg_names = [
+      list(inspect.signature(fun).parameters.keys()) for fun in funs
+  ]
+  funs_kwargs = [
+      {k: v for k, v in fun_kwargs.items() if k in fun_arg_names}
+      for fun_arg_names in funs_arg_names
+  ]
+  all_possible_arg_names = functools.reduce(operator.add, funs_arg_names)
+  remaining_keys = [
+      k for k in fun_kwargs.keys() if k not in all_possible_arg_names
+  ]
+  if remaining_keys:
+    raise ValueError(
+        f'{remaining_keys} are not valid arguments for any of the functions'
+        f' {funs}'
+    )
+  return funs_kwargs
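As a small complement to the doctest above, a sketch (illustrative only, not part of the patch) of the error path of `split_kwargs`: a keyword that matches no parameter of any of the given functions raises a `ValueError`.

from optax.experimental import utils as exp_utils


def fun1(a, b):
  return a + b


def fun2(c, d):
  return c + d


try:
  # 'e' is not an argument of fun1 or fun2, so split_kwargs rejects it.
  exp_utils.split_kwargs((fun1, fun2), {'b': 1.0, 'e': 2.0})
except ValueError as err:
  print(err)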