Skip to content

Commit

Permalink
Add tree_vdot to tree_utils.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 581979668
  • Loading branch information
mblondel authored and OptaxDev committed Nov 20, 2023
1 parent b7ff8ac commit a3b7b8c
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 17 deletions.
19 changes: 3 additions & 16 deletions optax/contrib/mechanic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,18 @@
using Mechanic to be the same for reasonably large batch sizes (>1k).
"""


import functools
import operator
from typing import NamedTuple, Optional, Tuple

import chex
import jax
import jax.numpy as jnp

from optax import tree_utils
from optax._src import base
from optax._src import utils


def _vdot_safe(a, b):
vdot = functools.partial(jnp.vdot, precision=jax.lax.Precision.HIGHEST)
cvdot = vdot(jnp.asarray(a), jnp.asarray(b))
return cvdot


@jax.jit
def _tree_vdot(tree_x, tree_y):
"""Compute the inner product <tree_x, tree_y>."""
vdots = jax.tree_util.tree_map(_vdot_safe, tree_x, tree_y)
return jax.tree_util.tree_reduce(operator.add, vdots)


@jax.jit
def _tree_sum(tree_x):
"""Compute sum(tree_x)."""
Expand Down Expand Up @@ -193,7 +180,7 @@ def add_weight_decay(gi, pi):
)

# Now we are ready to run the actual Mechanic algorithm.
h = _tree_vdot(updates, delta_prev)
h = tree_utils.tree_vdot(updates, delta_prev)

# This clipping was not part of the original paper but we introduced it
# a little later.
Expand Down
3 changes: 2 additions & 1 deletion optax/tree_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The tree_util optimisation sub-package."""
"""The tree_utils sub-package."""

from optax.tree_utils._state_utils import tree_map_params
from optax.tree_utils._tree_math import tree_vdot
54 changes: 54 additions & 0 deletions optax/tree_utils/_tree_math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities to perform maths on pytrees."""

import functools
import operator
from typing import Any

import chex
import jax
from jax import tree_util as jtu
import jax.numpy as jnp


_vdot = functools.partial(jnp.vdot, precision=jax.lax.Precision.HIGHEST)


def _vdot_safe(a, b):
return _vdot(jnp.asarray(a), jnp.asarray(b))


def tree_vdot(tree_x: Any, tree_y: Any) -> chex.Numeric:
r"""Compute the inner product between two pytrees.
Args:
tree_x: first pytree to use.
tree_y: second pytree to use.
Returns:
inner product between ``tree_x`` and ``tree_y``, a scalar value.
>>> optax.tree_utils.tree_vdot(
>>> {a: jnp.array([1, 2]), b: jnp.array([1, 2])},
>>> {a: jnp.array([-1, -1]), b: jnp.array([1, 1])},
>>> )
0.0
Implementation detail: we upcast the values to the highest precision to avoid
numerical issues.
"""
vdots = jtu.tree_map(_vdot_safe, tree_x, tree_y)
return jtu.tree_reduce(operator.add, vdots)
56 changes: 56 additions & 0 deletions optax/tree_utils/_tree_math_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for optax.tree_utils."""

from absl.testing import absltest

import jax.numpy as jnp
import numpy as np

from optax import tree_utils as tu


class TreeUtilsTest(absltest.TestCase):

def setUp(self):
super().setUp()
rng = np.random.RandomState(0)

self.tree_a = (rng.randn(20, 10) + 1j * rng.randn(20, 10), rng.randn(20))
self.tree_b = (rng.randn(20, 10), rng.randn(20))

self.tree_a_dict = (1.0, {'k1': 1.0, 'k2': (1.0, 1.0)}, 1.0)
self.tree_b_dict = (1.0, {'k1': 2.0, 'k2': (3.0, 4.0)}, 5.0)

self.array_a = rng.randn(20) + 1j * rng.randn(20)
self.array_b = rng.randn(20)

def test_tree_vdot(self):
expected = jnp.vdot(self.array_a, self.array_b)
got = tu.tree_vdot(self.array_a, self.array_b)
np.testing.assert_allclose(expected, got)

expected = 15.0
got = tu.tree_vdot(self.tree_a_dict, self.tree_b_dict)
np.testing.assert_allclose(expected, got)

expected = (jnp.vdot(self.tree_a[0], self.tree_b[0]) +
jnp.vdot(self.tree_a[1], self.tree_b[1]))
got = tu.tree_vdot(self.tree_a, self.tree_b)
np.testing.assert_allclose(expected, got)

if __name__ == '__main__':
absltest.main()

0 comments on commit a3b7b8c

Please sign in to comment.