Skip to content

Commit

Permalink
[tree-math] add replace() method
Browse files Browse the repository at this point in the history
This convenience method is copied from flax.struct.

PiperOrigin-RevId: 621690005
  • Loading branch information
shoyer authored and tree-math authors committed Apr 4, 2024
1 parent 4f9cd0a commit 0727453
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 2 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
setuptools.setup(
name='tree-math',
description='Mathematical operations for JAX pytrees',
version='0.2.0 ',
version='0.2.1',
license='Apache 2.0',
author='Google LLC',
author_email='[email protected]',
Expand Down
2 changes: 1 addition & 1 deletion tree_math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@
from tree_math._src.vector import Vector, VectorMixin
import tree_math.numpy

__version__ = '0.2.0'
__version__ = '0.2.1'
1 change: 1 addition & 0 deletions tree_math/_src/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def tree_unflatten(cls, _, children):
{'fields': fields,
'asdict': asdict,
'astuple': astuple,
'replace': dataclasses.replace,
'tree_flatten': tree_flatten,
'tree_unflatten': tree_unflatten,
'__module__': cls.__module__})
Expand Down
6 changes: 6 additions & 0 deletions tree_math/_src/structs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ def testPickle(self):
restored = pickle.loads(pickle.dumps(struct))
self.assertTreeEqual(struct, restored, check_dtypes=True)

def testReplace(self):
struct = TestStruct(1, 2)
replaced = struct.replace(b=3)
expected = TestStruct(1, 3)
self.assertTreeEqual(replaced, expected, check_dtypes=True)


if __name__ == '__main__':
absltest.main()

0 comments on commit 0727453

Please sign in to comment.