Commit
Merge pull request #1000 from fabianp:SauravMaheshkar-saurav/scale_by_grad_norm

PiperOrigin-RevId: 650548500
OptaxDev committed Jul 9, 2024
2 parents 648a967 + 2660a04 commit 5b55ccc
Showing 4 changed files with 53 additions and 0 deletions.
3 changes: 3 additions & 0 deletions docs/api/transformations.rst
@@ -32,6 +32,7 @@ Transformations
identity
keep_params_nonnegative
NonNegativeParamsState
normalize_by_update_norm
OptState
Params
per_example_global_norm_clip
@@ -159,6 +160,8 @@ Transformations and states
.. autofunction:: keep_params_nonnegative
.. autoclass:: NonNegativeParamsState

.. autofunction:: normalize_by_update_norm

.. autofunction:: per_example_global_norm_clip
.. autofunction:: per_example_layer_norm_clip

1 change: 1 addition & 0 deletions optax/__init__.py
@@ -110,6 +110,7 @@
from optax._src.transform import centralize
from optax._src.transform import ema
from optax._src.transform import EmaState
from optax._src.transform import normalize_by_update_norm
from optax._src.transform import scale
from optax._src.transform import scale_by_adadelta
from optax._src.transform import scale_by_adam
48 changes: 48 additions & 0 deletions optax/_src/transform.py
@@ -1492,6 +1492,54 @@ def update_fn(
return base.GradientTransformation(init_fn, update_fn)


def normalize_by_update_norm(
    scale_factor: float = 1.0, eps: float = 1e-6
) -> base.GradientTransformation:
  """Scale by the inverse of the update norm.

  Examples:
    >>> import optax
    >>> import jax
    >>> import jax.numpy as jnp
    >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
    >>> solver = optax.normalize_by_update_norm(scale_factor=-1.0)
    >>> params = jnp.array([1., 2., 3.])
    >>> print('Objective function: ', f(params))
    Objective function:  14.0
    >>> opt_state = solver.init(params)
    >>> for _ in range(5):
    ...   grad = jax.grad(f)(params)
    ...   updates, opt_state = solver.update(grad, opt_state, params)
    ...   params = optax.apply_updates(params, updates)
    ...   print('Objective function: {:.2E}'.format(f(params)))
    Objective function: 7.52E+00
    Objective function: 3.03E+00
    Objective function: 5.50E-01
    Objective function: 6.67E-02
    Objective function: 5.50E-01

  Args:
    scale_factor: factor by which the update will be multiplied (defaults to 1).
    eps: jitter term to avoid dividing by zero.

  Returns:
    A `GradientTransformation` object.
  """

  def update_fn(
      updates: base.Updates,
      state: base.EmptyState,
      params: Optional[base.Params] = None,
  ) -> tuple[base.Updates, base.EmptyState]:
    del params
    g_norm = (otu.tree_l2_norm(updates) + eps) / scale_factor
    updates = jtu.tree_map(lambda g: g / g_norm, updates)
    return updates, state

  return base.GradientTransformation(base.init_empty_state, update_fn)


### Legacy symbols to be removed. ###


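For context, a minimal sketch of how the new transform composes with the rest of the optax API via optax.chain; the quadratic objective, the 0.1 step size, and the loop length below are illustrative assumptions, not part of this commit.

import jax
import jax.numpy as jnp
import optax

# Normalize each update to unit L2 norm (sign flipped by scale_factor=-1.0),
# then apply a fixed step size.
tx = optax.chain(
    optax.normalize_by_update_norm(scale_factor=-1.0),
    optax.scale(0.1),  # illustrative step size
)

def loss(params):
  return jnp.sum(params ** 2)

params = jnp.array([1., 2., 3.])
opt_state = tx.init(params)
for _ in range(3):
  grads = jax.grad(loss)(params)
  updates, opt_state = tx.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)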
1 change: 1 addition & 0 deletions optax/_src/transform_test.py
@@ -52,6 +52,7 @@ def setUp(self):
('param_block_norm', transform.scale_by_param_block_norm),
('param_block_rms', transform.scale_by_param_block_rms),
('distance_over_gradients', transform.scale_by_distance_over_gradients),
('normalize_by_update_norm', transform.normalize_by_update_norm),
])
def test_scalers(self, scaler_constr):
params = self.init_params
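Beyond registering the transform in the existing test_scalers parameterization above, a property worth checking is that the returned update has global norm equal to abs(scale_factor). A standalone sketch of such a check (the test values and tolerance are assumptions, not part of this commit):

import numpy as np
import jax.numpy as jnp
import optax

def check_update_norm(scale_factor=0.5):
  # normalize_by_update_norm divides the update tree by its global L2 norm
  # and multiplies by scale_factor, so the result should have norm ~|scale_factor|.
  tx = optax.normalize_by_update_norm(scale_factor=scale_factor)
  grads = {'w': jnp.array([3., 4.]), 'b': jnp.array([0.])}
  state = tx.init(grads)
  updates, _ = tx.update(grads, state)
  np.testing.assert_allclose(
      float(optax.global_norm(updates)), abs(scale_factor), atol=1e-4)

check_update_norm()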
