From 0abda6e90f6cf88d31f556997880b1d94d4dd932 Mon Sep 17 00:00:00 2001 From: Eduardo Pignatelli Date: Mon, 17 Oct 2022 11:35:56 +0100 Subject: [PATCH] use `pow` rather than `pow_integer` or `bias_correction` --- optax/_src/transform.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 15ad3392..268730dd 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -109,9 +109,8 @@ def orderth_norm(g): def bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" - bias_correction_ = 1 - decay**count return jax.tree_util.tree_map( - lambda t: t / bias_correction_.astype(t.dtype), moment) + lambda t: t / (1 - decay ** count.astype(t.dtype)), moment) def _reject_complex(params):