diff --git a/optax/_src/linear_algebra.py b/optax/_src/linear_algebra.py index 420aff36..7bd64300 100644 --- a/optax/_src/linear_algebra.py +++ b/optax/_src/linear_algebra.py @@ -24,7 +24,7 @@ from optax._src import numerics -def global_norm(updates: base.Updates) -> base.Updates: +def global_norm(updates: base.PyTree) -> chex.Array: """Compute the global norm across a nested structure of tensors.""" return jnp.sqrt(sum( jnp.sum(numerics.abs_sq(x)) for x in jax.tree_util.tree_leaves(updates)))