diff --git a/tests/test_mixins.py b/tests/test_mixins.py new file mode 100644 index 0000000..143e604 --- /dev/null +++ b/tests/test_mixins.py @@ -0,0 +1,40 @@ +import typing as tp +from dataclasses import dataclass + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +import treeo as to + + +class TestMixins: + def test_apply(self): + @dataclass + class SomeTree(to.Tree, to.Apply): + x: int = to.node() + + tree = SomeTree(x=1) + + def f(tree: SomeTree): + tree.x = 2 + + tree2 = tree.apply(f) + + assert tree.x == 1 + assert tree2.x == 2 + + def test_apply_inplace(self): + @dataclass + class SomeTree(to.Tree, to.Apply): + x: int = to.node() + + tree = SomeTree(x=1) + + def f(tree: SomeTree): + tree.x = 2 + + tree.apply(f, inplace=True) + + assert tree.x == 2 diff --git a/treeo/mixins.py b/treeo/mixins.py index 5dc1cab..42e5fc4 100644 --- a/treeo/mixins.py +++ b/treeo/mixins.py @@ -219,7 +219,7 @@ def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) - Returns: A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`. """ - return api.apply(f, self, *rest, inplace=inplace) + return tree_m.apply(f, self, *rest, inplace=inplace) class Compact: diff --git a/treeo/tree.py b/treeo/tree.py index 00af759..2453949 100644 --- a/treeo/tree.py +++ b/treeo/tree.py @@ -173,7 +173,8 @@ def __call__(cls, *args, **kwargs) -> "Tree": return obj else: obj = cls.__new__(cls) - obj = cls.construct(obj, *args, **kwargs) + with _make_mutable_single(obj): + obj = cls.construct(obj, *args, **kwargs) if _COMPACT_CONTEXT.new_subtrees is not None: _COMPACT_CONTEXT.new_subtrees.append(obj) @@ -181,27 +182,27 @@ def __call__(cls, *args, **kwargs) -> "Tree": return obj def construct(cls, obj: T, *args, **kwargs) -> T: - with _make_mutable_single(obj): - obj._field_metadata = obj._field_metadata.copy() - # set default fields - for field, default_factory in obj._factory_fields.items(): - setattr(obj, field, default_factory()) + obj._field_metadata = obj._field_metadata.copy() - for field, default_value in obj._default_field_values.items(): - setattr(obj, field, default_value) + # set default fields + for field, default_factory in obj._factory_fields.items(): + setattr(obj, field, default_factory()) - # reset context before __init__ and add obj as current tree - with _CompactContext(current_tree=obj): - obj.__init__(*args, **kwargs) + for field, default_value in obj._default_field_values.items(): + setattr(obj, field, default_value) - # auto-annotations - obj._update_local_metadata() + # reset context before __init__ and add obj as current tree + with _CompactContext(current_tree=obj): + obj.__init__(*args, **kwargs) - if _COMPACT_CONTEXT.current_tree is not None: - obj._mutable = _COMPACT_CONTEXT.current_tree._mutable + # auto-annotations + obj._update_local_metadata() - return obj + if _COMPACT_CONTEXT.current_tree is not None: + obj._mutable = _COMPACT_CONTEXT.current_tree._mutable + + return obj class Tree(metaclass=TreeMeta):