-
Notifications
You must be signed in to change notification settings - Fork 183
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PiperOrigin-RevId: 581979668
- Loading branch information
Showing
4 changed files
with
115 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
"""Utilities to perform maths on pytrees.""" | ||
|
||
import functools | ||
import operator | ||
from typing import Any | ||
|
||
import chex | ||
import jax | ||
from jax import tree_util as jtu | ||
import jax.numpy as jnp | ||
|
||
|
||
_vdot = functools.partial(jnp.vdot, precision=jax.lax.Precision.HIGHEST) | ||
|
||
|
||
def _vdot_safe(a, b): | ||
return _vdot(jnp.asarray(a), jnp.asarray(b)) | ||
|
||
|
||
def tree_vdot(tree_x: Any, tree_y: Any) -> chex.Numeric: | ||
r"""Compute the inner product between two pytrees. | ||
Args: | ||
tree_x: first pytree to use. | ||
tree_y: second pytree to use. | ||
Returns: | ||
inner product between ``tree_x`` and ``tree_y``, a scalar value. | ||
>>> optax.tree_utils.tree_vdot( | ||
>>> {a: jnp.array([1, 2]), b: jnp.array([1, 2])}, | ||
>>> {a: jnp.array([-1, -1]), b: jnp.array([1, 1])}, | ||
>>> ) | ||
0.0 | ||
Implementation detail: we upcast the values to the highest precision to avoid | ||
numerical issues. | ||
""" | ||
vdots = jtu.tree_map(_vdot_safe, tree_x, tree_y) | ||
return jtu.tree_reduce(operator.add, vdots) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
"""Tests for optax.tree_utils.""" | ||
|
||
from absl.testing import absltest | ||
|
||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
from optax import tree_utils as tu | ||
|
||
|
||
class TreeUtilsTest(absltest.TestCase): | ||
|
||
def setUp(self): | ||
super().setUp() | ||
rng = np.random.RandomState(0) | ||
|
||
self.tree_a = (rng.randn(20, 10) + 1j * rng.randn(20, 10), rng.randn(20)) | ||
self.tree_b = (rng.randn(20, 10), rng.randn(20)) | ||
|
||
self.tree_a_dict = (1.0, {'k1': 1.0, 'k2': (1.0, 1.0)}, 1.0) | ||
self.tree_b_dict = (1.0, {'k1': 2.0, 'k2': (3.0, 4.0)}, 5.0) | ||
|
||
self.array_a = rng.randn(20) + 1j * rng.randn(20) | ||
self.array_b = rng.randn(20) | ||
|
||
def test_tree_vdot(self): | ||
expected = jnp.vdot(self.array_a, self.array_b) | ||
got = tu.tree_vdot(self.array_a, self.array_b) | ||
np.testing.assert_allclose(expected, got) | ||
|
||
expected = 15.0 | ||
got = tu.tree_vdot(self.tree_a_dict, self.tree_b_dict) | ||
np.testing.assert_allclose(expected, got) | ||
|
||
expected = (jnp.vdot(self.tree_a[0], self.tree_b[0]) + | ||
jnp.vdot(self.tree_a[1], self.tree_b[1])) | ||
got = tu.tree_vdot(self.tree_a, self.tree_b) | ||
np.testing.assert_allclose(expected, got) | ||
|
||
if __name__ == '__main__': | ||
absltest.main() |