Added mathematical documentation to AdaMax #918

Merged
merged 2 commits on Apr 10, 2024
32 changes: 31 additions & 1 deletion optax/_src/alias.py
@@ -1581,7 +1581,37 @@ def adamax(
b2: float = 0.999,
eps: float = 1e-8,
) -> base.GradientTransformation:
"""A variant of the Adam optimizer that uses the infinity norm.
r"""A variant of the Adam optimizer that uses the infinity norm.

AdaMax is a variant of the :func:`optax.adam` optimizer. By generalizing
Adam's :math:`L^2` norm to an :math:`L^p` norm and taking the limit as
:math:`p \rightarrow \infty`, we obtain a simple and stable update rule.

Let :math:`\alpha_t` represent the learning rate and :math:`\beta_1, \beta_2,
\varepsilon` represent the arguments ``b1``, ``b2``, and ``eps`` respectively.
The learning rate is indexed by :math:`t` because it may also be supplied by a
schedule function.

The ``init`` function of this optimizer initializes an internal state
:math:`S_0 := (m_0, v_0) = (0, 0)`, representing initial estimates for the
first and second moments. In practice these values are stored as pytrees
containing all zeros, with the same shape as the model updates.
At step :math:`t`, the ``update`` function of this optimizer takes as
arguments the incoming gradients :math:`g_t` and the optimizer state
:math:`S_{t-1}`, and computes updates :math:`u_t` and the new state
:math:`S_t`. Thus, for :math:`t > 0`, we have,

.. math::

\begin{align*}
m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
v_t &\leftarrow \max(\left| g_t \right| + \varepsilon, \beta_2 \cdot v_{t-1}) \\
\hat{m}_t &\leftarrow m_t / (1-\beta_1^t) \\
u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / v_t \\
S_t &\leftarrow (m_t, v_t).
\end{align*}

Examples:
>>> import optax
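For readers who want the update rule spelled out concretely, here is a minimal
sketch of a single AdaMax step mirroring the docstring's equations. The
function name ``adamax_step`` and its signature are illustrative only, not
part of optax's API; the real implementation operates on arbitrary gradient
pytrees.

import jax.numpy as jnp

def adamax_step(g, m, v, t, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
    # m_t: exponential moving average of the gradients.
    m = b1 * m + (1 - b1) * g
    # v_t: exponentially weighted infinity norm; the max keeps v_t >= eps > 0.
    v = jnp.maximum(jnp.abs(g) + eps, b2 * v)
    # Bias-corrected first moment estimate (t is 1-indexed).
    m_hat = m / (1 - b1 ** t)
    # u_t: the update that will be added to the parameters.
    u = -lr * m_hat / v
    return u, m, v

Because :math:`v_t \geq |g_t| + \varepsilon > 0` by construction, the division
is always well defined; this variant places :math:`\varepsilon` inside the
:math:`\max` rather than in the denominator as Adam does.

A short usage sketch of the public optimizer, in the same doctest style as the
(truncated) ``Examples:`` block above; the parameter values are arbitrary:

>>> import optax
>>> import jax.numpy as jnp
>>> params = {'w': jnp.zeros(3)}
>>> solver = optax.adamax(learning_rate=0.003)
>>> opt_state = solver.init(params)
>>> grads = {'w': jnp.array([0.1, -0.2, 0.3])}
>>> updates, opt_state = solver.update(grads, opt_state)
>>> params = optax.apply_updates(params, updates)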