-
Notifications
You must be signed in to change notification settings - Fork 183
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PiperOrigin-RevId: 604411756
- Loading branch information
Showing
6 changed files
with
389 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
As the name suggests, the functions implemented here are subject to | ||
modifications. We are currently developing a new Solver API that could | ||
span more optimizers such as the ones using some linesearches. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "vfWSk55u5_E-" | ||
}, | ||
"source": [ | ||
"# Example of Solver API usage\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "vzyIF6NW6Dwd" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import jax.numpy as jnp\n", | ||
"import optax\n", | ||
"from optax.experimental import gradient_solver\n", | ||
"\n", | ||
"def obj_fun(x):\n", | ||
" return jnp.sum(x**2)\n", | ||
"\n", | ||
"init, step = gradient_solver.gradient_solver(\n", | ||
" obj_fun, optax.adam(learning_rate=1.)\n", | ||
" )\n", | ||
"\n", | ||
"params = jnp.arange(16, dtype=jnp.float32)\n", | ||
"state = init(params)\n", | ||
"for _ in range(10):\n", | ||
" params, state = step(params, state)\n", | ||
" print(f'Objective value: {obj_fun(params)}')" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"colab": { | ||
"last_runtime": { | ||
"build_target": "//learning/grp/tools/ml_python:ml_notebook", | ||
"kind": "private" | ||
}, | ||
"private_outputs": true, | ||
"provenance": [] | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Wraps a GradientTransform into a Solver.""" | ||
|
||
from typing import Any, NamedTuple, Union | ||
|
||
import jax | ||
import optax | ||
import optax.experimental.solver as optax_solver | ||
import optax.experimental.utils as exp_utils | ||
|
||
|
||
class GradientSolverState(NamedTuple): | ||
gt_state: optax.OptState = None | ||
|
||
|
||
def gradient_solver(obj_fn, gradient_transform, obj_fun_has_aux=False): | ||
"""Wraps a GradientTransform into a Solver.""" | ||
|
||
def init(init_params: optax.Params) -> optax_solver.SolverState: | ||
init_gt_state = gradient_transform.init(init_params) | ||
init_opt_state = GradientSolverState(init_gt_state) | ||
return init_opt_state | ||
|
||
def step( | ||
params: optax.Params, | ||
state: optax_solver.SolverState, | ||
**extra_kwargs: dict[str, Any] | ||
) -> tuple[ | ||
Union[optax.Params, tuple[optax.Params, Any]], optax_solver.SolverState | ||
]: | ||
obj_kwargs, gt_kwargs = exp_utils.split_kwargs( | ||
(obj_fn, gradient_transform.update), extra_kwargs | ||
) | ||
if obj_fun_has_aux: | ||
grad, aux = jax.grad(obj_fn, has_aux=obj_fun_has_aux)( | ||
params, **obj_kwargs | ||
) | ||
else: | ||
grad = jax.grad(obj_fn)(params, **obj_kwargs) | ||
aux = None | ||
update, gt_state = gradient_transform.update( | ||
grad, state.gt_state, params, **gt_kwargs | ||
) | ||
next_params = optax.apply_updates(params, update) | ||
next_state = GradientSolverState(gt_state) | ||
if obj_fun_has_aux: | ||
return (next_params, aux), next_state | ||
else: | ||
return next_params, next_state | ||
|
||
return optax_solver.Solver(init, step) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for `alias.py`.""" | ||
|
||
from absl.testing import absltest | ||
from absl.testing import parameterized | ||
|
||
import chex | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
from optax._src import alias | ||
from optax._src import numerics | ||
from optax.experimental import gradient_solver | ||
|
||
|
||
_GRAD_TRANSFORMS_UNDER_TEST = ( | ||
dict(gt_name='sgd', gt_kwargs=dict(learning_rate=1e-3, momentum=0.9)), | ||
dict(gt_name='adafactor', gt_kwargs=dict(learning_rate=5e-3)), | ||
dict(gt_name='adagrad', gt_kwargs=dict(learning_rate=1.0)), | ||
dict(gt_name='adam', gt_kwargs=dict(learning_rate=1e-1)), | ||
dict(gt_name='adamw', gt_kwargs=dict(learning_rate=1e-1)), | ||
dict(gt_name='adamax', gt_kwargs=dict(learning_rate=1e-1)), | ||
dict(gt_name='adamaxw', gt_kwargs=dict(learning_rate=1e-1)), | ||
dict(gt_name='amsgrad', gt_kwargs=dict(learning_rate=1e-1)), | ||
dict(gt_name='lars', gt_kwargs=dict(learning_rate=1.0)), | ||
dict(gt_name='lamb', gt_kwargs=dict(learning_rate=1e-3)), | ||
dict( | ||
gt_name='lion', gt_kwargs=dict(learning_rate=1e-2, weight_decay=1e-4), | ||
), | ||
dict(gt_name='nadam', gt_kwargs=dict(learning_rate=1e-2)), | ||
dict(gt_name='nadamw', gt_kwargs=dict(learning_rate=1e-2)), | ||
dict(gt_name='noisy_sgd', gt_kwargs=dict(learning_rate=1e-3, eta=1e-4)), | ||
dict(gt_name='novograd', gt_kwargs=dict(learning_rate=1e-3)), | ||
dict( | ||
gt_name='optimistic_gradient_descent', | ||
gt_kwargs=dict(learning_rate=2e-3, alpha=0.7, beta=0.1), | ||
), | ||
dict(gt_name='rmsprop', gt_kwargs=dict(learning_rate=5e-3)), | ||
dict(gt_name='rmsprop', gt_kwargs=dict(learning_rate=5e-3, momentum=0.9)), | ||
dict(gt_name='fromage', gt_kwargs=dict(learning_rate=5e-3)), | ||
dict(gt_name='adabelief', gt_kwargs=dict(learning_rate=1e-2)), | ||
dict(gt_name='radam', gt_kwargs=dict(learning_rate=5e-3)), | ||
dict(gt_name='rprop', gt_kwargs=dict(learning_rate=1e-1)), | ||
dict(gt_name='sm3', gt_kwargs=dict(learning_rate=1.0)), | ||
dict(gt_name='yogi', gt_kwargs=dict(learning_rate=1e-1)), | ||
) | ||
|
||
|
||
def _setup_parabola(dtype): | ||
"""Quadratic function as an optimization target.""" | ||
initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype) | ||
final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype) | ||
|
||
def obj_fun(params): | ||
return jnp.sum(numerics.abs_sq(params - final_params)) | ||
|
||
return initial_params, final_params, obj_fun | ||
|
||
|
||
def _setup_rosenbrock(dtype): | ||
"""Rosenbrock function as an optimization target.""" | ||
a = 1.0 | ||
b = 100.0 | ||
|
||
initial_params = jnp.array([0.0, 0.0], dtype=dtype) | ||
final_params = jnp.array([a, a**2], dtype=dtype) | ||
|
||
def obj_fun(params): | ||
return (numerics.abs_sq(a - params[0]) + | ||
b * numerics.abs_sq(params[1] - params[0]**2)) | ||
|
||
return initial_params, final_params, obj_fun | ||
|
||
|
||
class SolverWrapperTest(chex.TestCase): | ||
|
||
@parameterized.product( | ||
_GRAD_TRANSFORMS_UNDER_TEST, | ||
target=(_setup_parabola, _setup_rosenbrock), | ||
dtype=(jnp.float32,), | ||
) | ||
def test_optimization(self, gt_name, gt_kwargs, target, dtype): | ||
opt = getattr(alias, gt_name)(**gt_kwargs) | ||
initial_params, final_params, obj_fun = target(dtype) | ||
|
||
init, step = gradient_solver.gradient_solver(obj_fun, opt) | ||
|
||
params = initial_params | ||
state = init(params) | ||
step = jax.jit(step) | ||
for _ in range(10_000): | ||
params, state = step(params, state) | ||
|
||
chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2) | ||
|
||
if __name__ == '__main__': | ||
absltest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Solver API.""" | ||
|
||
from typing import Any, NamedTuple, Protocol, Union | ||
|
||
import optax | ||
|
||
Params = optax.Params | ||
SolverState = Any | ||
|
||
|
||
class SolverInitFn(Protocol): | ||
"""A callable type for the `init` function of a `Solver`. | ||
The `init` function takes a tree of `params` and uses these to construct an | ||
arbitrary structured initial `state` for the solver. This | ||
may hold statistics of the past updates or any other non static information. | ||
""" | ||
|
||
def __call__(self, params: Params) -> SolverState: | ||
"""Initialize the solver. | ||
Args: | ||
params: The initial value of the parameters. | ||
Returns: | ||
The initial state of the solver. | ||
""" | ||
|
||
|
||
class SolverStepFn(Protocol): | ||
"""A callable type for the `step` function of a `Solver`. | ||
The `step` function takes a tree of candidate parameters `params`, and an | ||
arbitrary structured `state` to return a new tree of candidate parameters, | ||
and a new state. Additional arguments can be fed in a keyword format. | ||
""" | ||
|
||
def __call__( | ||
self, params: Params, state: SolverState, **extra_kwargs: dict[str, Any] | ||
) -> tuple[Union[Params, tuple[Params, Any]], SolverState]: | ||
"""Performs a step of the solver. | ||
Args: | ||
params: A tree of candidate parameters. | ||
state: The state of the solver. | ||
**extra_kwargs: Additional arguments for the function or the solver in | ||
keyword format. | ||
Returns: | ||
The updated parameters, eventually with an auxiliary output, | ||
and updated state. | ||
""" | ||
|
||
|
||
class Solver(NamedTuple): | ||
"""A pair of pure functions implementing a solver. | ||
The init function initializes the state of the solver given an initial tree of | ||
parameters. The step function updates the parameters and the state of the | ||
solver given current parameters and state. | ||
Contrarily to GradientTransformation, this API accesses the function to be | ||
optimized directly to compute gradients, then update directions and finally | ||
updated parameters. | ||
Attributes: | ||
init: A pure function which, when called with an example instance of the | ||
parameters, returns an arbitrary structured initial `state`. | ||
step: A pure function which takes as input a tree of parameters, the | ||
previous solver state (which may have been initialized using the init | ||
function). The step function then returns the updated parameters, | ||
and a new solver state. | ||
""" | ||
|
||
init_fn: SolverInitFn | ||
step_fn: SolverStepFn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Utilities for solvers. | ||
""" | ||
import functools | ||
import inspect | ||
import operator | ||
|
||
from typing import Any, Callable, Sequence | ||
|
||
|
||
def split_kwargs( | ||
funs: Sequence[Callable[..., Any]], | ||
fun_kwargs: dict[str, Any], | ||
) -> Sequence[dict[str, Any]]: | ||
"""Split fun_kwargs into kwargs of the input functions funs. | ||
Raises an error in one keyword argument of fun_kwargs does not match any | ||
argument name of funs. | ||
Args: | ||
funs: sequence of functions to feed fun_kwargs to | ||
fun_kwargs: dictionary of keyword variables to be fed to funs | ||
Returns: | ||
(fun_1_kwargs, ..., fun_n_kwargs): keyword arguments for each function taken | ||
from fun_kwargs. | ||
Examples: | ||
>>> def fun1(a, b): return a+b | ||
>>> def fun2(c, d): return c+d | ||
>>> fun_kwargs = {'b':1., 'd':2.} | ||
>>> funs_kwargs = split_kwargs((fun1, fun2), fun_kwargs) | ||
>>> print(funs_kwargs) | ||
[{'b': 1.0}, {'d': 2.0}] | ||
""" | ||
funs_arg_names = [ | ||
list(inspect.signature(fun).parameters.keys()) for fun in funs | ||
] | ||
funs_kwargs = [ | ||
{k: v for k, v in fun_kwargs.items() if k in fun_arg_names} | ||
for fun_arg_names in funs_arg_names | ||
] | ||
all_possible_arg_names = functools.reduce(operator.add, funs_arg_names) | ||
remaining_keys = [ | ||
k for k in fun_kwargs.keys() if k not in all_possible_arg_names | ||
] | ||
if remaining_keys: | ||
raise ValueError( | ||
f'{remaining_keys} are not valid arguments for any of the functions' | ||
f' {funs}' | ||
) | ||
return funs_kwargs |