Skip to content
This repository was archived by the owner on Feb 26, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions tests/test_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import typing as tp
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import numpy as np
import pytest

import treeo as to


class TestMixins:
def test_apply(self):
@dataclass
class SomeTree(to.Tree, to.Apply):
x: int = to.node()

tree = SomeTree(x=1)

def f(tree: SomeTree):
tree.x = 2

tree2 = tree.apply(f)

assert tree.x == 1
assert tree2.x == 2

def test_apply_inplace(self):
@dataclass
class SomeTree(to.Tree, to.Apply):
x: int = to.node()

tree = SomeTree(x=1)

def f(tree: SomeTree):
tree.x = 2

tree.apply(f, inplace=True)

assert tree.x == 2
2 changes: 1 addition & 1 deletion treeo/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
return api.apply(f, self, *rest, inplace=inplace)
return tree_m.apply(f, self, *rest, inplace=inplace)


class Compact:
Expand Down
33 changes: 17 additions & 16 deletions treeo/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,35 +173,36 @@ def __call__(cls, *args, **kwargs) -> "Tree":
return obj
else:
obj = cls.__new__(cls)
obj = cls.construct(obj, *args, **kwargs)
with _make_mutable_single(obj):
obj = cls.construct(obj, *args, **kwargs)

if _COMPACT_CONTEXT.new_subtrees is not None:
_COMPACT_CONTEXT.new_subtrees.append(obj)

return obj

def construct(cls, obj: T, *args, **kwargs) -> T:
with _make_mutable_single(obj):
obj._field_metadata = obj._field_metadata.copy()

# set default fields
for field, default_factory in obj._factory_fields.items():
setattr(obj, field, default_factory())
obj._field_metadata = obj._field_metadata.copy()

for field, default_value in obj._default_field_values.items():
setattr(obj, field, default_value)
# set default fields
for field, default_factory in obj._factory_fields.items():
setattr(obj, field, default_factory())

# reset context before __init__ and add obj as current tree
with _CompactContext(current_tree=obj):
obj.__init__(*args, **kwargs)
for field, default_value in obj._default_field_values.items():
setattr(obj, field, default_value)

# auto-annotations
obj._update_local_metadata()
# reset context before __init__ and add obj as current tree
with _CompactContext(current_tree=obj):
obj.__init__(*args, **kwargs)

if _COMPACT_CONTEXT.current_tree is not None:
obj._mutable = _COMPACT_CONTEXT.current_tree._mutable
# auto-annotations
obj._update_local_metadata()

return obj
if _COMPACT_CONTEXT.current_tree is not None:
obj._mutable = _COMPACT_CONTEXT.current_tree._mutable

return obj


class Tree(metaclass=TreeMeta):
Expand Down