Skip to content

Commit

Permalink
Revert pull request #329 to fix numerical instabilities.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 486636531
  • Loading branch information
hbq1 authored and OptaxDev committed Nov 7, 2022
1 parent 509c706 commit 9dcfa4a
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion optax/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,9 @@ def orderth_norm(g):

def bias_correction(moment, decay, count):
"""Perform bias correction. This becomes a no-op as count goes to infinity."""
bias_correction_ = 1 - decay**count
return jax.tree_util.tree_map(
lambda t: t / (1 - decay ** count.astype(t.dtype)), moment)
lambda t: t / bias_correction_.astype(t.dtype), moment)


def _reject_complex(params):
Expand Down

0 comments on commit 9dcfa4a

Please sign in to comment.