From 0727453f840d4b3b7033bc1faf2c1815e78b94a6 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 3 Apr 2024 17:51:37 -0700 Subject: [PATCH] [tree-math] add replace() method This convenience method is copied from flax.struct. PiperOrigin-RevId: 621690005 --- setup.py | 2 +- tree_math/__init__.py | 2 +- tree_math/_src/structs.py | 1 + tree_math/_src/structs_test.py | 6 ++++++ 4 files changed, 9 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 65be738..b4ed0f3 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ setuptools.setup( name='tree-math', description='Mathematical operations for JAX pytrees', - version='0.2.0 ', + version='0.2.1', license='Apache 2.0', author='Google LLC', author_email='noreply@google.com', diff --git a/tree_math/__init__.py b/tree_math/__init__.py index 981869e..c2e0241 100644 --- a/tree_math/__init__.py +++ b/tree_math/__init__.py @@ -24,4 +24,4 @@ from tree_math._src.vector import Vector, VectorMixin import tree_math.numpy -__version__ = '0.2.0' +__version__ = '0.2.1' diff --git a/tree_math/_src/structs.py b/tree_math/_src/structs.py index 57e135d..2a16af6 100644 --- a/tree_math/_src/structs.py +++ b/tree_math/_src/structs.py @@ -72,6 +72,7 @@ def tree_unflatten(cls, _, children): {'fields': fields, 'asdict': asdict, 'astuple': astuple, + 'replace': dataclasses.replace, 'tree_flatten': tree_flatten, 'tree_unflatten': tree_unflatten, '__module__': cls.__module__}) diff --git a/tree_math/_src/structs_test.py b/tree_math/_src/structs_test.py index 5b9670e..879e69c 100644 --- a/tree_math/_src/structs_test.py +++ b/tree_math/_src/structs_test.py @@ -109,6 +109,12 @@ def testPickle(self): restored = pickle.loads(pickle.dumps(struct)) self.assertTreeEqual(struct, restored, check_dtypes=True) + def testReplace(self): + struct = TestStruct(1, 2) + replaced = struct.replace(b=3) + expected = TestStruct(1, 3) + self.assertTreeEqual(replaced, expected, check_dtypes=True) + if __name__ == '__main__': absltest.main()