diff --git a/docs/api/transformations.rst b/docs/api/transformations.rst
index c670c626..da3e3728 100644
--- a/docs/api/transformations.rst
+++ b/docs/api/transformations.rst
@@ -32,6 +32,7 @@ Transformations
     identity
     keep_params_nonnegative
     NonNegativeParamsState
+    normalize_by_update_norm
     OptState
     Params
     per_example_global_norm_clip
@@ -159,6 +160,8 @@ Transformations and states
 .. autofunction:: keep_params_nonnegative
 .. autoclass:: NonNegativeParamsState
 
+.. autofunction:: normalize_by_update_norm
+
 .. autofunction:: per_example_global_norm_clip
 .. autofunction:: per_example_layer_norm_clip
 
diff --git a/optax/__init__.py b/optax/__init__.py
index f0446937..9679775b 100644
--- a/optax/__init__.py
+++ b/optax/__init__.py
@@ -110,6 +110,7 @@
 from optax._src.transform import centralize
 from optax._src.transform import ema
 from optax._src.transform import EmaState
+from optax._src.transform import normalize_by_update_norm
 from optax._src.transform import scale
 from optax._src.transform import scale_by_adadelta
 from optax._src.transform import scale_by_adam
diff --git a/optax/_src/transform.py b/optax/_src/transform.py
index b8a5bbd1..8b85d940 100644
--- a/optax/_src/transform.py
+++ b/optax/_src/transform.py
@@ -1492,6 +1492,53 @@ def update_fn(
   return base.GradientTransformation(init_fn, update_fn)
 
 
+def normalize_by_update_norm(
+    scale_factor: float = 1.0, eps: float = 1e-6
+) -> base.GradientTransformation:
+  """Scale by the inverse of the update norm.
+
+  Examples:
+    >>> import optax
+    >>> import jax
+    >>> import jax.numpy as jnp
+    >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
+    >>> solver = optax.normalize_by_update_norm(scale_factor=-1.0)
+    >>> params = jnp.array([1., 2., 3.])
+    >>> print('Objective function: ', f(params))
+    Objective function:  14.0
+    >>> opt_state = solver.init(params)
+    >>> for _ in range(5):
+    ...   grad = jax.grad(f)(params)
+    ...   updates, opt_state = solver.update(grad, opt_state, params)
+    ...   params = optax.apply_updates(params, updates)
+    ...   print('Objective function: {:.2E}'.format(f(params)))
+    Objective function: 7.52E+00
+    Objective function: 3.03E+00
+    Objective function: 5.50E-01
+    Objective function: 6.67E-02
+    Objective function: 5.50E-01
+
+  Args:
+    scale_factor: factor by which the update is multiplied (defaults to 1.0).
+    eps: jitter term to avoid dividing by zero.
+
+  Returns:
+    A `GradientTransformation` object.
+  """
+
+  def update_fn(
+      updates: base.Updates,
+      state: base.EmptyState,
+      params: Optional[base.Params] = None,
+  ) -> tuple[base.Updates, base.EmptyState]:
+    del params
+    g_norm = (otu.tree_l2_norm(updates) + eps) / scale_factor
+    updates = jtu.tree_map(lambda g: g / g_norm, updates)
+    return updates, state
+
+  return base.GradientTransformation(base.init_empty_state, update_fn)
+
+
 ### Legacy symbols to be removed. ###
 
 
diff --git a/optax/_src/transform_test.py b/optax/_src/transform_test.py
index a6e5d6fc..afac35ee 100644
--- a/optax/_src/transform_test.py
+++ b/optax/_src/transform_test.py
@@ -52,6 +52,7 @@ def setUp(self):
       ('param_block_norm', transform.scale_by_param_block_norm),
      ('param_block_rms', transform.scale_by_param_block_rms),
       ('distance_over_gradients', transform.scale_by_distance_over_gradients),
+      ('normalize_by_update_norm', transform.normalize_by_update_norm),
   ])
   def test_scalers(self, scaler_constr):
     params = self.init_params
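
Usage note (not part of the patch): below is a minimal sketch of how the new transform could be combined with existing optax transformations once this change lands. The pairing with optax.sgd and the 0.5 step size are illustrative assumptions, not something the diff prescribes.

# Sketch only: normalized gradient descent built from the transform added above.
# The chain with optax.sgd and the 0.5 step size are illustrative choices.
import jax
import jax.numpy as jnp
import optax


def loss(params):
  return jnp.sum(params ** 2)


# Normalize the raw gradient to unit norm, then apply a fixed step size.
solver = optax.chain(
    optax.normalize_by_update_norm(scale_factor=1.0),
    optax.sgd(learning_rate=0.5),
)

params = jnp.array([1.0, 2.0, 3.0])
opt_state = solver.init(params)
for _ in range(3):
  grads = jax.grad(loss)(params)
  updates, opt_state = solver.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  print(loss(params))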