From f08636ed7bb89025f7604164c96a0dcaab58e37e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 14 May 2024 18:56:50 -0700 Subject: [PATCH] Replace deprecated `jax.tree_*` functions with `jax.tree.*` The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25. PiperOrigin-RevId: 633773679 --- tree_math/_src/structs_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tree_math/_src/structs_test.py b/tree_math/_src/structs_test.py index 879e69c..b3af767 100644 --- a/tree_math/_src/structs_test.py +++ b/tree_math/_src/structs_test.py @@ -41,8 +41,8 @@ class StructsTest(test_util.TestCase): dict(testcase_name='Arrays', x=TestStruct(np.eye(10), np.ones([3, 4, 5]))) ) def testFlattenUnflatten(self, x): - leaves, structure = jax.tree_flatten(x) - y = jax.tree_unflatten(structure, leaves) + leaves, structure = jax.tree.flatten(x) + y = jax.tree.unflatten(structure, leaves) np.testing.assert_allclose(x.a, y.a) np.testing.assert_allclose(x.b, y.b)